diff --git a/CLAUDE.md b/CLAUDE.md index ace6f553..456213b7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -103,6 +103,30 @@ from cryptography.hazmat.primitives.asymmetric import x25519 This keeps code readable and avoids mental overhead of tracking renamed imports. +### Type Annotations + +**Never quote type annotations when `from __future__ import annotations` is present.** With future annotations, all annotations are already lazy strings. Adding quotes is redundant and noisy. + +Bad: +```python +from __future__ import annotations + +def create(cls) -> "Store": + ... +``` + +Good: +```python +from __future__ import annotations + +def create(cls) -> Store: + ... +``` + +The only valid use of quoted annotations is in files that do NOT have `from __future__ import annotations` and need a forward reference. Prefer adding the future import instead. + +**Prefer narrow domain types over raw builtins.** Use `Bytes32`, `Bytes33`, `Bytes52` etc. instead of `bytes` in signatures. Spec code should never accept or return `bytes` when a more specific type exists. + ### Module-Level Constants Use docstrings (not comments) to document module-level constants. Place the docstring immediately after the assignment. diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index 4f67712b..30693e24 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -157,7 +157,7 @@ def set_max_slot_default(self) -> Self: if isinstance(step, BlockStep): max_slot_value = max(max_slot_value, step.block.slot) elif isinstance(step, AttestationStep): - max_slot_value = max(max_slot_value, step.attestation.message.slot) + max_slot_value = max(max_slot_value, step.attestation.data.slot) self.max_slot = max_slot_value @@ -219,9 +219,8 @@ def make_fixture(self) -> Self: # # The Store is the node's local view of the chain. # It starts from a trusted anchor (usually genesis). - store = Store.get_forkchoice_store( - anchor_state=self.anchor_state, - anchor_block=self.anchor_block, + store = self.anchor_state.to_forkchoice_store( + self.anchor_block, validator_id=DEFAULT_VALIDATOR_ID, ) @@ -400,7 +399,7 @@ def _build_block_from_spec( working_store = working_store.on_gossip_attestation( SignedAttestation( validator_id=attestation.validator_id, - message=attestation.data, + data=attestation.data, signature=signature, ), scheme=LEAN_ENV_TO_SCHEMES[self.lean_env], diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py index 0bbc614b..32331602 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py +++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py @@ -23,7 +23,7 @@ ) from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.state.state import State -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.ssz import hash_tree_root from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof from lean_spec.types import Bytes32 @@ -42,7 +42,7 @@ def _create_dummy_aggregated_proof(validator_ids: list[ValidatorIndex]) -> Aggre so it will fail verification. """ return AggregatedSignatureProof( - participants=AggregationBits.from_validator_indices(validator_ids), + participants=AggregationBits.from_validator_indices(ValidatorIndices(data=validator_ids)), proof_data=ByteListMiB(data=b"\x00" * 32), # Invalid proof bytes ) @@ -216,7 +216,9 @@ def _build_block_from_spec( data_root = attestation_data.data_root_bytes() # Create aggregated attestation claiming validator_ids as participants - aggregation_bits = AggregationBits.from_validator_indices(invalid_spec.validator_ids) + aggregation_bits = AggregationBits.from_validator_indices( + ValidatorIndices(data=invalid_spec.validator_ids) + ) invalid_aggregated = AggregatedAttestation( aggregation_bits=aggregation_bits, data=attestation_data, @@ -238,7 +240,9 @@ def _build_block_from_spec( ] # Create valid aggregated proof from actual signers valid_proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(invalid_spec.signer_ids), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=invalid_spec.signer_ids) + ), public_keys=signer_public_keys, signatures=signer_signatures, message=data_root, diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py index b98a6ca2..fc666070 100644 --- a/src/lean_spec/__main__.py +++ b/src/lean_spec/__main__.py @@ -36,7 +36,6 @@ from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.validator import SubnetId -from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.genesis import GenesisConfig from lean_spec.subspecs.networking.client import LiveNetworkEventSource from lean_spec.subspecs.networking.enr import ENR @@ -271,7 +270,7 @@ async def _init_from_checkpoint( # The store treats this as the new "genesis" for fork choice purposes. # All blocks before the checkpoint are effectively pruned. validator_id = validator_registry.primary_index() if validator_registry else None - store = Store.get_forkchoice_store(state, anchor_block, validator_id) + store = state.to_forkchoice_store(anchor_block, validator_id) logger.info( "Initialized from checkpoint at slot %d (finalized=%s)", state.slot, diff --git a/src/lean_spec/subspecs/chain/clock.py b/src/lean_spec/subspecs/chain/clock.py index 0ef5502c..e8a2f44a 100644 --- a/src/lean_spec/subspecs/chain/clock.py +++ b/src/lean_spec/subspecs/chain/clock.py @@ -9,9 +9,9 @@ """ import asyncio +from collections.abc import Callable from dataclasses import dataclass from time import time as wall_time -from typing import Callable from lean_spec.subspecs.containers.slot import Slot from lean_spec.types import Uint64 diff --git a/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py b/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py index 3b12ea31..097bb527 100644 --- a/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py +++ b/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py @@ -19,15 +19,12 @@ class AggregationBits(BaseBitlist): LIMIT = int(VALIDATOR_REGISTRY_LIMIT) @classmethod - def from_validator_indices( - cls, indices: "ValidatorIndices | list[ValidatorIndex]" - ) -> AggregationBits: + def from_validator_indices(cls, indices: ValidatorIndices) -> AggregationBits: """ Construct aggregation bits from a set of validator indices. Args: - indices: Validator indices to set in the bitlist. Accepts either - a ValidatorIndices collection or a plain list of ValidatorIndex. + indices: Validator indices to set in the bitlist. Returns: AggregationBits with the corresponding indices set to True. @@ -36,8 +33,7 @@ def from_validator_indices( AssertionError: If no indices are provided. AssertionError: If any index is outside the supported LIMIT. """ - # Extract list from ValidatorIndices if needed - index_list = indices.data if isinstance(indices, ValidatorIndices) else indices + index_list = indices.data # Require at least one validator for a valid aggregation. if not index_list: @@ -57,7 +53,7 @@ def from_validator_indices( # - False elsewhere. return cls(data=[Boolean(i in ids) for i in range(max_id + 1)]) - def to_validator_indices(self) -> "ValidatorIndices": + def to_validator_indices(self) -> ValidatorIndices: """ Extract all validator indices encoded in these aggregation bits. diff --git a/src/lean_spec/subspecs/containers/attestation/attestation.py b/src/lean_spec/subspecs/containers/attestation/attestation.py index 092e355b..89efa4ad 100644 --- a/src/lean_spec/subspecs/containers/attestation/attestation.py +++ b/src/lean_spec/subspecs/containers/attestation/attestation.py @@ -16,7 +16,7 @@ from collections import defaultdict from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.ssz import hash_tree_root from lean_spec.types import Bytes32, Container @@ -56,15 +56,9 @@ class Attestation(Container): """The attestation data produced by the validator.""" -class SignedAttestation(Container): +class SignedAttestation(Attestation): """Validator attestation bundled with its signature.""" - validator_id: ValidatorIndex - """The index of the validator making the attestation.""" - - message: AttestationData - """The attestation message signed by the validator.""" - signature: Signature """Signature aggregation produced by the leanVM (SNARKs in the future).""" @@ -103,7 +97,9 @@ def aggregate_by_data( return [ cls( - aggregation_bits=AggregationBits.from_validator_indices(validator_ids), + aggregation_bits=AggregationBits.from_validator_indices( + ValidatorIndices(data=validator_ids) + ), data=data, ) for data, validator_ids in data_to_validator_ids.items() diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index 57d6e2bb..4b54b08d 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -2,8 +2,10 @@ from __future__ import annotations -from typing import AbstractSet, Collection, Iterable +from typing import TYPE_CHECKING, AbstractSet, Collection, Iterable +from lean_spec.subspecs.chain.clock import Interval +from lean_spec.subspecs.chain.config import INTERVALS_PER_SLOT from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.subspecs.xmss.aggregation import ( AggregatedSignatureProof, @@ -24,7 +26,7 @@ from ..checkpoint import Checkpoint from ..config import Config from ..slot import Slot -from ..validator import ValidatorIndex +from ..validator import ValidatorIndex, ValidatorIndices from .types import ( HistoricalBlockHashes, JustificationRoots, @@ -33,6 +35,9 @@ Validators, ) +if TYPE_CHECKING: + from lean_spec.subspecs.forkchoice import Store + class State(Container): """The main consensus state object.""" @@ -73,7 +78,7 @@ class State(Container): """A bitlist of validators who participated in justifications.""" @classmethod - def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "State": + def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> State: """ Generate a genesis state with empty history and proper initial values. @@ -117,7 +122,72 @@ def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "Stat justifications_validators=JustificationValidators(data=[]), ) - def process_slots(self, target_slot: Slot) -> "State": + def to_forkchoice_store( + self, + anchor_block: Block, + validator_id: ValidatorIndex | None, + ) -> Store: + """ + Initialize a forkchoice store from this state and an anchor block. + + The anchor block and this state form the starting point for fork choice. + Both are treated as justified and finalized. + + Args: + anchor_block: A trusted block (e.g. genesis or checkpoint). + validator_id: Index of the validator running this store. + + Returns: + A new Store instance, ready to accept blocks and attestations. + + Raises: + AssertionError: + If the anchor block's state root does not match the hash + of this state. + """ + from lean_spec.subspecs.forkchoice import Store + + # Compute the SSZ root of this state. + # + # This is the canonical hash that should appear in the block's state root. + computed_state_root = hash_tree_root(self) + + # Check that the block actually points to this state. + # + # If this fails, the caller has supplied inconsistent inputs. + assert anchor_block.state_root == computed_state_root, ( + "Anchor block state root must match anchor state hash" + ) + + # Compute the SSZ root of the anchor block itself. + # + # This root will be used as: + # - the key in the blocks/states maps, + # - the initial head, + # - the root of the initial checkpoints. + anchor_root = hash_tree_root(anchor_block) + + # Read the slot at which the anchor block was proposed. + anchor_slot = anchor_block.slot + + # Initialize checkpoints from this state. + # + # We explicitly set the root to the anchor block root. + # The state internally might have zero-hash checkpoints (if genesis), + # but the Store must treat the anchor block as the justified/finalized point. + return Store( + time=Interval(anchor_slot * INTERVALS_PER_SLOT), + config=self.config, + head=anchor_root, + safe_target=anchor_root, + latest_justified=self.latest_justified.model_copy(update={"root": anchor_root}), + latest_finalized=self.latest_finalized.model_copy(update={"root": anchor_root}), + blocks={anchor_root: anchor_block}, + states={anchor_root: self}, + validator_id=validator_id, + ) + + def process_slots(self, target_slot: Slot) -> State: """ Advance the state through empty slots up to, but not including, target_slot. @@ -192,7 +262,7 @@ def process_slots(self, target_slot: Slot) -> "State": # Reached the target slot. Return the advanced state. return state - def process_block_header(self, block: Block) -> "State": + def process_block_header(self, block: Block) -> State: """ Validate the block header and update header-linked state. @@ -341,7 +411,7 @@ def process_block_header(self, block: Block) -> "State": } ) - def process_block(self, block: Block) -> "State": + def process_block(self, block: Block) -> State: """ Apply full block processing including header and body. @@ -368,7 +438,7 @@ def process_block(self, block: Block) -> "State": def process_attestations( self, attestations: Iterable[AggregatedAttestation], - ) -> "State": + ) -> State: """ Apply attestations and update justification/finalization according to the Lean Consensus 3SF-mini rules. @@ -617,7 +687,7 @@ def process_attestations( } ) - def state_transition(self, block: Block, valid_signatures: bool = True) -> "State": + def state_transition(self, block: Block, valid_signatures: bool = True) -> State: """ Apply the complete state transition function for a block. @@ -670,7 +740,7 @@ def build_block( available_attestations: Iterable[Attestation] | None = None, known_block_roots: AbstractSet[Bytes32] | None = None, aggregated_payloads: dict[SignatureKey, list[AggregatedSignatureProof]] | None = None, - ) -> tuple[Block, "State", list[AggregatedAttestation], list[AggregatedSignatureProof]]: + ) -> tuple[Block, State, list[AggregatedAttestation], list[AggregatedSignatureProof]]: """ Build a valid block on top of this state. @@ -858,7 +928,9 @@ def aggregate_gossip_signatures( # The aggregation combines multiple XMSS signatures into a single # compact proof that can verify all participants signed the message. if gossip_ids: - participants = AggregationBits.from_validator_indices(gossip_ids) + participants = AggregationBits.from_validator_indices( + ValidatorIndices(data=gossip_ids) + ) proof = AggregatedSignatureProof.aggregate( participants=participants, public_keys=gossip_keys, diff --git a/src/lean_spec/subspecs/containers/state/types.py b/src/lean_spec/subspecs/containers/state/types.py index db53bbee..631e6648 100644 --- a/src/lean_spec/subspecs/containers/state/types.py +++ b/src/lean_spec/subspecs/containers/state/types.py @@ -70,7 +70,7 @@ def with_justified( finalized_slot: Slot, target_slot: Slot, value: Boolean, - ) -> "JustifiedSlots": + ) -> JustifiedSlots: """ Return a new bitfield with the justification status updated. diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py index 827b93db..58b424c6 100644 --- a/src/lean_spec/subspecs/forkchoice/store.py +++ b/src/lean_spec/subspecs/forkchoice/store.py @@ -154,76 +154,6 @@ class Store(Container): Used for recursive signature aggregation when building blocks. """ - @classmethod - def get_forkchoice_store( - cls, - anchor_state: State, - anchor_block: Block, - validator_id: ValidatorIndex | None, - ) -> "Store": - """ - Initialize forkchoice store from an anchor state and block. - - The anchor block and its state form the starting point for fork choice. - We treat this anchor as both justified and finalized. - - Args: - anchor_state: The state corresponding to the anchor block. - anchor_block: A trusted block (e.g. genesis or checkpoint). - validator_id: Index of the validator running this store. - - Returns: - A new Store instance, ready to accept blocks and attestations. - - Raises: - AssertionError: - If the anchor block's state root does not match the hash - of the provided state. - """ - # Compute the SSZ root of the given state. - # - # This is the canonical hash that should appear in the block's state root. - computed_state_root = hash_tree_root(anchor_state) - - # Check that the block actually points to this state. - # - # If this fails, the caller has supplied inconsistent inputs. - assert anchor_block.state_root == computed_state_root, ( - "Anchor block state root must match anchor state hash" - ) - - # Compute the SSZ root of the anchor block itself. - # - # This root will be used as: - # - the key in the blocks/states maps, - # - the initial head, - # - the root of the initial checkpoints. - anchor_root = hash_tree_root(anchor_block) - - # Read the slot at which the anchor block was proposed. - anchor_slot = anchor_block.slot - - # Build an initial checkpoint using the anchor block. - # - # Both the root and the slot come directly from the anchor. - # Initialize checkpoints from the anchor state - # - # We explicitly set the root to the anchor block root. - # The anchor state internally might have zero-hash checkpoints (if genesis), - # but the Store must treat the anchor block as the justified/finalized point. - - return cls( - time=Interval(anchor_slot * INTERVALS_PER_SLOT), - config=anchor_state.config, - head=anchor_root, - safe_target=anchor_root, - latest_justified=anchor_state.latest_justified.model_copy(update={"root": anchor_root}), - latest_finalized=anchor_state.latest_finalized.model_copy(update={"root": anchor_root}), - blocks={anchor_root: anchor_block}, - states={anchor_root: anchor_state}, - validator_id=validator_id, - ) - def prune_stale_attestation_data(self) -> "Store": """ Remove attestation data that can no longer influence fork choice. @@ -376,7 +306,7 @@ def on_gossip_attestation( AssertionError: If signature verification fails. """ validator_id = signed_attestation.validator_id - attestation_data = signed_attestation.message + attestation_data = signed_attestation.data signature = signed_attestation.signature # Validate the attestation first so unknown blocks are rejected cleanly @@ -651,8 +581,7 @@ def on_block( # Store proposer signature for future lookup if it belongs to the same committee # as the current validator. if self.validator_id is not None: - proposer_validator_id = proposer_attestation.validator_id - proposer_subnet_id = proposer_validator_id.compute_subnet_id( + proposer_subnet_id = proposer_attestation.validator_id.compute_subnet_id( ATTESTATION_COMMITTEE_COUNT ) current_validator_subnet_id = self.validator_id.compute_subnet_id( @@ -1010,7 +939,7 @@ def aggregate_committee_signatures(self) -> tuple["Store", list[SignedAggregated # Each SignatureKey contains (validator_id, data_root) # We look up the full AttestationData from attestation_data_by_root attestation_list: list[Attestation] = [] - for sig_key in self.gossip_signatures.keys(): + for sig_key in self.gossip_signatures: data_root = sig_key.data_root attestation_data = self.attestation_data_by_root.get(data_root) if attestation_data is not None: diff --git a/src/lean_spec/subspecs/networking/config.py b/src/lean_spec/subspecs/networking/config.py index 51e0cf36..8750dc3e 100644 --- a/src/lean_spec/subspecs/networking/config.py +++ b/src/lean_spec/subspecs/networking/config.py @@ -2,8 +2,6 @@ from typing import Final -from lean_spec.types.byte_arrays import Bytes1 - from .types import DomainType # --- Request/Response Limits --- @@ -32,13 +30,13 @@ # --- Gossip Message Domains --- -MESSAGE_DOMAIN_INVALID_SNAPPY: Final[DomainType] = Bytes1(b"\x00") +MESSAGE_DOMAIN_INVALID_SNAPPY: Final[DomainType] = DomainType(b"\x00") """1-byte domain for gossip message-id isolation of invalid snappy messages. Per Ethereum spec, prepended to the message hash when decompression fails. """ -MESSAGE_DOMAIN_VALID_SNAPPY: Final[DomainType] = Bytes1(b"\x01") +MESSAGE_DOMAIN_VALID_SNAPPY: Final[DomainType] = DomainType(b"\x01") """1-byte domain for gossip message-id isolation of valid snappy messages. Per Ethereum spec, prepended to the message hash when decompression succeeds. diff --git a/src/lean_spec/subspecs/networking/discovery/crypto.py b/src/lean_spec/subspecs/networking/discovery/crypto.py index 62521446..e9fcf29a 100644 --- a/src/lean_spec/subspecs/networking/discovery/crypto.py +++ b/src/lean_spec/subspecs/networking/discovery/crypto.py @@ -155,11 +155,6 @@ def aes_ctr_encrypt(key: Bytes16, iv: Bytes16, plaintext: bytes) -> bytes: Returns: Ciphertext of same length as plaintext. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(iv) != CTR_IV_SIZE: - raise ValueError(f"IV must be {CTR_IV_SIZE} bytes, got {len(iv)}") - cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) encryptor = cipher.encryptor() return encryptor.update(plaintext) + encryptor.finalize() @@ -198,11 +193,6 @@ def aes_gcm_encrypt(key: Bytes16, nonce: Bytes12, plaintext: bytes, aad: bytes) Returns: Ciphertext with 16-byte authentication tag appended. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - aesgcm = AESGCM(key) return aesgcm.encrypt(nonce, plaintext, aad) @@ -225,16 +215,11 @@ def aes_gcm_decrypt(key: Bytes16, nonce: Bytes12, ciphertext: bytes, aad: bytes) Raises: cryptography.exceptions.InvalidTag: If authentication fails. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - aesgcm = AESGCM(key) return aesgcm.decrypt(nonce, ciphertext, aad) -def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: bytes) -> Bytes33: +def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: Bytes33 | Bytes65) -> Bytes33: """ Perform secp256k1 ECDH key agreement. @@ -252,9 +237,6 @@ def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: bytes) -> Bytes33: Returns: 33-byte shared secret (compressed point from ECDH). """ - if len(private_key_bytes) != 32: - raise ValueError(f"Private key must be 32 bytes, got {len(private_key_bytes)}") - scalar = int.from_bytes(private_key_bytes, "big") point = _decompress_pubkey(public_key_bytes) result = _point_mul(scalar, point) @@ -287,7 +269,7 @@ def generate_secp256k1_keypair() -> tuple[Bytes32, Bytes33]: return Bytes32(private_bytes), Bytes33(public_bytes) -def pubkey_to_uncompressed(public_key_bytes: bytes) -> Bytes65: +def pubkey_to_uncompressed(public_key_bytes: Bytes33 | Bytes65) -> Bytes65: """ Convert any secp256k1 public key to uncompressed format. @@ -337,11 +319,6 @@ def sign_id_nonce( Returns: 64-byte signature (r || s, each 32 bytes). """ - if len(private_key_bytes) != 32: - raise ValueError(f"Private key must be 32 bytes, got {len(private_key_bytes)}") - if len(dest_node_id) != 32: - raise ValueError(f"Dest node ID must be 32 bytes, got {len(dest_node_id)}") - # The signing input binds several values together per the spec: # # - Domain separator prevents cross-protocol signature reuse @@ -404,11 +381,6 @@ def verify_id_nonce_signature( Returns: True if signature is valid, False otherwise. """ - if len(signature) != ID_SIGNATURE_SIZE: - return False - if len(dest_node_id) != 32: - return False - # Build the signing input per spec: # domain-separator || challenge-data || ephemeral-pubkey || node-id-B input_data = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 2848ac9e..146e15ab 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -32,7 +32,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Bytes64 +from lean_spec.types import Bytes16, Bytes32, Bytes33 from .config import HANDSHAKE_TIMEOUT_SECS from .crypto import ( @@ -41,7 +41,7 @@ verify_id_nonce_signature, ) from .keys import derive_keys_from_pubkey -from .messages import PacketFlag, Port +from .messages import IdNonce, Nonce, PacketFlag, Port from .packet import ( HandshakeAuthdata, WhoAreYouAuthdata, @@ -88,16 +88,16 @@ class PendingHandshake: remote_node_id: NodeId """32-byte node ID of the remote peer.""" - id_nonce: bytes | None = None + id_nonce: IdNonce | None = None """16-byte challenge nonce (set when WHOAREYOU sent/received).""" challenge_data: bytes | None = None """Full WHOAREYOU packet data for key derivation (masking-iv || static-header || authdata).""" - ephemeral_privkey: bytes | None = None + ephemeral_privkey: Bytes32 | None = None """32-byte ephemeral private key (set when we send HANDSHAKE).""" - challenge_nonce: bytes | None = None + challenge_nonce: Nonce | None = None """12-byte nonce from the packet that triggered WHOAREYOU.""" remote_enr_seq: SeqNumber = SeqNumber(0) @@ -111,7 +111,7 @@ def is_expired(self, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS) -> bool: return time.time() - self.started_at > timeout_secs -@dataclass +@dataclass(frozen=True, slots=True) class HandshakeResult: """Result of a completed handshake.""" @@ -145,18 +145,13 @@ class HandshakeManager: def __init__( self, local_node_id: NodeId, - local_private_key: bytes, + local_private_key: Bytes32, local_enr_rlp: bytes, local_enr_seq: SeqNumber, session_cache: SessionCache, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS, ): """Initialize handshake manager.""" - if len(local_node_id) != 32: - raise ValueError(f"Local node ID must be 32 bytes, got {len(local_node_id)}") - if len(local_private_key) != 32: - raise ValueError(f"Local private key must be 32 bytes, got {len(local_private_key)}") - self._local_node_id = local_node_id self._local_private_key = local_private_key self._local_enr_rlp = local_enr_rlp @@ -206,10 +201,10 @@ def start_handshake(self, remote_node_id: NodeId) -> PendingHandshake: def create_whoareyou( self, remote_node_id: NodeId, - request_nonce: bytes, + request_nonce: Nonce, remote_enr_seq: SeqNumber, - masking_iv: bytes, - ) -> tuple[bytes, bytes, bytes, bytes]: + masking_iv: Bytes16, + ) -> tuple[bytes, bytes, Nonce, bytes]: """ Create a WHOAREYOU packet in response to an undecryptable message. @@ -229,20 +224,20 @@ def create_whoareyou( - challenge_data: Full data for key derivation (masking-iv || static-header || authdata) """ id_nonce = generate_id_nonce() - authdata = encode_whoareyou_authdata(bytes(id_nonce), remote_enr_seq) + authdata = encode_whoareyou_authdata(id_nonce, remote_enr_seq) # Build challenge_data per spec: masking-iv || static-header || authdata. # # This data becomes the HKDF salt for session key derivation. # Both sides must use identical challenge_data to derive matching keys. static_header = encode_static_header(PacketFlag.WHOAREYOU, request_nonce, len(authdata)) - challenge_data = masking_iv + static_header + authdata + challenge_data = bytes(masking_iv) + static_header + authdata with self._lock: pending = PendingHandshake( state=HandshakeState.SENT_WHOAREYOU, remote_node_id=remote_node_id, - id_nonce=bytes(id_nonce), + id_nonce=id_nonce, challenge_data=challenge_data, challenge_nonce=request_nonce, remote_enr_seq=remote_enr_seq, @@ -255,11 +250,11 @@ def create_handshake_response( self, remote_node_id: NodeId, whoareyou: WhoAreYouAuthdata, - remote_pubkey: bytes, + remote_pubkey: Bytes33, challenge_data: bytes, remote_ip: str = "", remote_port: Port = _DEFAULT_PORT, - ) -> tuple[bytes, bytes, bytes]: + ) -> tuple[bytes, Bytes16, Bytes16]: """ Create a HANDSHAKE packet in response to WHOAREYOU. @@ -288,10 +283,10 @@ def create_handshake_response( # Per spec, the signature input includes the full challenge_data (not just id_nonce) # to bind the signature to this specific WHOAREYOU exchange. id_signature = sign_id_nonce( - Bytes32(self._local_private_key), + self._local_private_key, challenge_data, eph_pubkey, - Bytes32(remote_node_id), + remote_node_id, ) # Include our ENR if the remote's known seq is stale. @@ -314,8 +309,8 @@ def create_handshake_response( send_key, recv_key = derive_keys_from_pubkey( local_private_key=eph_privkey, remote_public_key=remote_pubkey, - local_node_id=Bytes32(self._local_node_id), - remote_node_id=Bytes32(remote_node_id), + local_node_id=self._local_node_id, + remote_node_id=remote_node_id, challenge_data=challenge_data, is_initiator=True, ) @@ -402,11 +397,11 @@ def handle_handshake( # The signature was computed over challenge_data (not just id_nonce), # and includes our node_id as the WHOAREYOU sender (node-id-B). if not verify_id_nonce_signature( - signature=Bytes64(handshake.id_signature), + signature=handshake.id_signature, challenge_data=challenge_data, - ephemeral_pubkey=Bytes33(handshake.eph_pubkey), - dest_node_id=Bytes32(self._local_node_id), - public_key_bytes=Bytes33(remote_pubkey), + ephemeral_pubkey=handshake.eph_pubkey, + dest_node_id=self._local_node_id, + public_key_bytes=remote_pubkey, ): raise HandshakeError("Invalid ID signature") @@ -415,10 +410,10 @@ def handle_handshake( # The challenge_data was saved when we sent WHOAREYOU. # Using the same data ensures both sides derive identical keys. send_key, recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(self._local_private_key), + local_private_key=self._local_private_key, remote_public_key=handshake.eph_pubkey, - local_node_id=Bytes32(self._local_node_id), - remote_node_id=Bytes32(remote_node_id), + local_node_id=self._local_node_id, + remote_node_id=remote_node_id, challenge_data=challenge_data, is_initiator=False, ) @@ -471,7 +466,7 @@ def cleanup_expired(self) -> int: del self._pending[node_id] return len(expired) - def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes | None: + def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> Bytes33 | None: """ Retrieve the remote node's static public key for signature verification. @@ -503,7 +498,7 @@ def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes # to ensure the ENR belongs to who we think sent it. computed_id = enr.compute_node_id() if computed_id is not None and bytes(computed_id) == node_id: - return bytes(enr.public_key) + return Bytes33(enr.public_key) except (ValueError, KeyError, IndexError): pass @@ -513,7 +508,7 @@ def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes # Use it if the handshake packet did not include an ENR. cached_enr = self._enr_cache.get(node_id) if cached_enr is not None and cached_enr.public_key is not None: - return bytes(cached_enr.public_key) + return Bytes33(cached_enr.public_key) return None diff --git a/src/lean_spec/subspecs/networking/discovery/keys.py b/src/lean_spec/subspecs/networking/discovery/keys.py index c664c9cb..dc974823 100644 --- a/src/lean_spec/subspecs/networking/discovery/keys.py +++ b/src/lean_spec/subspecs/networking/discovery/keys.py @@ -30,7 +30,7 @@ from Crypto.Hash import keccak -from lean_spec.types import Bytes16, Bytes32, Bytes33 +from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes65 from .crypto import ecdh_agree, pubkey_to_uncompressed @@ -75,13 +75,6 @@ def derive_keys( The initiator uses initiator_key to encrypt and recipient_key to decrypt. The recipient uses recipient_key to encrypt and initiator_key to decrypt. """ - if len(secret) != 33: - raise ValueError(f"Secret must be 33 bytes, got {len(secret)}") - if len(initiator_id) != 32: - raise ValueError(f"Initiator ID must be 32 bytes, got {len(initiator_id)}") - if len(recipient_id) != 32: - raise ValueError(f"Recipient ID must be 32 bytes, got {len(recipient_id)}") - # HKDF-Extract: PRK = HMAC-SHA256(salt, IKM). # # Using challenge_data as salt binds session keys to the specific WHOAREYOU. @@ -111,7 +104,7 @@ def derive_keys( def derive_keys_from_pubkey( local_private_key: Bytes32, - remote_public_key: bytes, + remote_public_key: Bytes33 | Bytes65, local_node_id: Bytes32, remote_node_id: Bytes32, challenge_data: bytes, @@ -125,7 +118,7 @@ def derive_keys_from_pubkey( Args: local_private_key: Our 32-byte secp256k1 private key. - remote_public_key: Peer's compressed public key. + remote_public_key: Peer's compressed (33-byte) or uncompressed (65-byte) public key. local_node_id: Our 32-byte node ID. remote_node_id: Peer's 32-byte node ID. challenge_data: WHOAREYOU packet data (masking-iv || static-header || authdata). @@ -154,7 +147,7 @@ def derive_keys_from_pubkey( return recipient_key, initiator_key -def compute_node_id(public_key_bytes: bytes) -> Bytes32: +def compute_node_id(public_key_bytes: Bytes33 | Bytes65) -> Bytes32: """ Compute node ID from public key. diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py index a34cdd7b..b7feeda3 100644 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ b/src/lean_spec/subspecs/networking/discovery/messages.py @@ -77,11 +77,12 @@ class Nonce(BaseBytes): LENGTH: ClassVar[int] = 12 -Distance = Uint16 -"""Log2 distance (0-256). Distance 0 returns the node's own ENR.""" +class Distance(Uint16): + """Log2 distance (0-256). Distance 0 returns the node's own ENR.""" -Port = Uint16 -"""UDP port number (0-65535).""" + +class Port(Uint16): + """UDP port number (0-65535).""" class PacketFlag(IntEnum): diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index 45896c7e..18b73dab 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -34,15 +34,13 @@ from dataclasses import dataclass from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16 +from lean_spec.types import Bytes12, Bytes16, Bytes33, Bytes64 from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE from .crypto import ( AES_KEY_SIZE, - COMPRESSED_PUBKEY_SIZE, CTR_IV_SIZE, GCM_NONCE_SIZE, - ID_SIGNATURE_SIZE, aes_ctr_decrypt, aes_ctr_encrypt, aes_gcm_decrypt, @@ -109,10 +107,10 @@ class HandshakeAuthdata: eph_key_size: int """Size of ephemeral public key. 33 for compressed secp256k1.""" - id_signature: bytes + id_signature: Bytes64 """ID nonce signature proving identity ownership.""" - eph_pubkey: bytes + eph_pubkey: Bytes33 """Ephemeral public key for ECDH.""" record: bytes | None @@ -122,10 +120,10 @@ class HandshakeAuthdata: def encode_packet( dest_node_id: NodeId, flag: PacketFlag, - nonce: bytes, + nonce: Nonce, authdata: bytes, message: bytes, - encryption_key: bytes | None = None, + encryption_key: Bytes16 | None = None, masking_iv: Bytes16 | None = None, ) -> bytes: """ @@ -144,11 +142,6 @@ def encode_packet( Returns: Complete encoded packet ready for UDP transmission. """ - if len(dest_node_id) != 32: - raise ValueError(f"Destination node ID must be 32 bytes, got {len(dest_node_id)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - if masking_iv is None: # Fresh random IV for header masking. # @@ -181,9 +174,7 @@ def encode_packet( # The AAD binds the plaintext header to the encrypted message. # The recipient reconstructs this from the decoded header. message_ad = bytes(masking_iv) + header - encrypted_message = aes_gcm_encrypt( - Bytes16(encryption_key), Bytes12(nonce), message, message_ad - ) + encrypted_message = aes_gcm_encrypt(encryption_key, Bytes12(nonce), message, message_ad) # Assemble packet. packet = bytes(masking_iv) + masked_header + encrypted_message @@ -294,10 +285,10 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: raise ValueError(f"Handshake authdata truncated: {len(authdata)} < {expected_min}") offset = HANDSHAKE_HEADER_SIZE - id_signature = authdata[offset : offset + sig_size] + id_signature = Bytes64(authdata[offset : offset + sig_size]) offset += sig_size - eph_pubkey = authdata[offset : offset + eph_key_size] + eph_pubkey = Bytes33(authdata[offset : offset + eph_key_size]) offset += eph_key_size # Remaining bytes are the RLP-encoded ENR, included when the recipient's @@ -315,8 +306,8 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: def decrypt_message( - encryption_key: bytes, - nonce: bytes, + encryption_key: Bytes16, + nonce: Nonce, ciphertext: bytes, message_ad: bytes, ) -> bytes: @@ -332,27 +323,23 @@ def decrypt_message( Returns: Decrypted message plaintext. """ - return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, message_ad) + return aes_gcm_decrypt(encryption_key, Bytes12(nonce), ciphertext, message_ad) def encode_message_authdata(src_id: NodeId) -> bytes: """Encode MESSAGE packet authdata.""" - if len(src_id) != 32: - raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}") return src_id -def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: SeqNumber) -> bytes: +def encode_whoareyou_authdata(id_nonce: IdNonce, enr_seq: SeqNumber) -> bytes: """Encode WHOAREYOU packet authdata.""" - if len(id_nonce) != 16: - raise ValueError(f"ID nonce must be 16 bytes, got {len(id_nonce)}") return id_nonce + struct.pack(">Q", enr_seq) def encode_handshake_authdata( src_id: NodeId, - id_signature: bytes, - eph_pubkey: bytes, + id_signature: Bytes64, + eph_pubkey: Bytes33, record: bytes | None = None, ) -> bytes: """ @@ -367,15 +354,6 @@ def encode_handshake_authdata( Returns: Encoded authdata bytes. """ - if len(src_id) != 32: - raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}") - if len(id_signature) != ID_SIGNATURE_SIZE: - raise ValueError(f"Signature must be {ID_SIGNATURE_SIZE} bytes, got {len(id_signature)}") - if len(eph_pubkey) != COMPRESSED_PUBKEY_SIZE: - raise ValueError( - f"Ephemeral pubkey must be {COMPRESSED_PUBKEY_SIZE} bytes, got {len(eph_pubkey)}" - ) - authdata = src_id + bytes([len(id_signature), len(eph_pubkey)]) + id_signature + eph_pubkey if record is not None: @@ -394,7 +372,7 @@ def generate_id_nonce() -> IdNonce: return IdNonce(os.urandom(16)) -def encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes: +def encode_static_header(flag: PacketFlag, nonce: Nonce, authdata_size: int) -> bytes: """Encode the 23-byte static header.""" return ( PROTOCOL_ID diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py index c3a0280f..8e17d3f5 100644 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ b/src/lean_spec/subspecs/networking/discovery/service.py @@ -33,6 +33,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes32, Bytes33 from lean_spec.types.uint import Uint8 from .codec import DiscoveryMessage @@ -90,7 +91,7 @@ class DiscoveryService: def __init__( self, local_enr: ENR, - private_key: bytes, + private_key: Bytes32, config: DiscoveryConfig | None = None, bootnodes: list[ENR] | None = None, ): @@ -103,7 +104,7 @@ def __init__( # Compute our node ID from public key. if local_enr.public_key is None: raise ValueError("Local ENR must have a public key") - self._local_node_id = NodeId(compute_node_id(bytes(local_enr.public_key))) + self._local_node_id = NodeId(compute_node_id(Bytes33(local_enr.public_key))) # Initialize routing table. self._routing_table = RoutingTable(local_id=NodeId(self._local_node_id)) diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 6708d556..f9fdffd9 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -25,6 +25,7 @@ from threading import Lock from lean_spec.subspecs.networking.types import NodeId +from lean_spec.types import Bytes16 from .config import BOND_EXPIRY_SECS from .messages import Port @@ -51,10 +52,10 @@ class Session: node_id: NodeId """Peer's 32-byte node ID.""" - send_key: bytes + send_key: Bytes16 """16-byte key for encrypting messages to this peer.""" - recv_key: bytes + recv_key: Bytes16 """16-byte key for decrypting messages from this peer.""" created_at: float @@ -134,8 +135,8 @@ def get(self, node_id: NodeId, ip: str = "", port: Port = _DEFAULT_PORT) -> Sess def create( self, node_id: NodeId, - send_key: bytes, - recv_key: bytes, + send_key: Bytes16, + recv_key: Bytes16, is_initiator: bool, ip: str = "", port: Port = _DEFAULT_PORT, @@ -157,13 +158,6 @@ def create( Returns: The newly created session. """ - if len(node_id) != 32: - raise ValueError(f"Node ID must be 32 bytes, got {len(node_id)}") - if len(send_key) != 16: - raise ValueError(f"Send key must be 16 bytes, got {len(send_key)}") - if len(recv_key) != 16: - raise ValueError(f"Recv key must be 16 bytes, got {len(recv_key)}") - key: SessionKey = (node_id, ip, port) now = time.time() session = Session( @@ -257,7 +251,7 @@ def _evict_oldest(self) -> None: del self.sessions[oldest_key] -@dataclass +@dataclass(slots=True) class BondCache: """ Cache tracking which nodes we have successfully bonded with. diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index b9fe97b2..b9d9f449 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -25,7 +25,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16 +from lean_spec.types import Bytes16, Bytes32, Bytes33 from .codec import ( DiscoveryMessage, @@ -78,7 +78,7 @@ class PendingRequest: sent_at: float """Timestamp when request was sent.""" - nonce: bytes + nonce: Nonce """Packet nonce (needed for WHOAREYOU handling).""" message: DiscoveryMessage @@ -105,7 +105,7 @@ class PendingMultiRequest: sent_at: float """Timestamp when request was sent.""" - nonce: bytes + nonce: Nonce """Packet nonce (needed for WHOAREYOU handling).""" message: DiscoveryMessage @@ -129,7 +129,7 @@ class DiscoveryProtocol(asyncio.DatagramProtocol): transport_handler: Parent transport for packet handling. """ - def __init__(self, transport_handler: DiscoveryTransport): + def __init__(self, transport_handler: DiscoveryTransport) -> None: """Initialize protocol handler.""" self._handler = transport_handler self._transport: asyncio.DatagramTransport | None = None @@ -172,7 +172,7 @@ class DiscoveryTransport: def __init__( self, local_node_id: NodeId, - local_private_key: bytes, + local_private_key: Bytes32, local_enr: ENR, config: DiscoveryConfig | None = None, ): @@ -390,7 +390,7 @@ async def _send_multi_response_request( request_id=request_id_bytes, dest_node_id=dest_node_id, sent_at=loop.time(), - nonce=bytes(nonce), + nonce=nonce, message=message, response_queue=response_queue, expected_total=None, @@ -498,7 +498,7 @@ async def _send_request( request_id=request_id_bytes, dest_node_id=dest_node_id, sent_at=loop.time(), - nonce=bytes(nonce), + nonce=nonce, message=message, future=future, ) @@ -547,7 +547,7 @@ def _build_message_packet( return encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=session.send_key, @@ -565,11 +565,11 @@ def _build_message_packet( # This approach avoids the need for session negotiation # before sending the first message. self._handshake_manager.start_handshake(dest_node_id) - dummy_key = os.urandom(16) + dummy_key = Bytes16(os.urandom(16)) return encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=dummy_key, @@ -626,7 +626,7 @@ async def _handle_whoareyou( # This links the challenge to the specific request that failed. pending = None for p in self._pending_requests.values(): - if p.nonce == bytes(header.nonce): + if p.nonce == header.nonce: pending = p break @@ -647,7 +647,7 @@ async def _handle_whoareyou( # We use the unmasked header, which we can reconstruct from the decoded values. masking_iv = raw_packet[:16] static_header = encode_static_header( - PacketFlag.WHOAREYOU, bytes(header.nonce), len(header.authdata) + PacketFlag.WHOAREYOU, header.nonce, len(header.authdata) ) challenge_data = masking_iv + static_header + header.authdata @@ -660,7 +660,7 @@ async def _handle_whoareyou( logger.debug("No ENR for %s, cannot complete handshake", remote_node_id.hex()[:16]) return - remote_pubkey = bytes(remote_enr.public_key) + remote_pubkey = Bytes33(remote_enr.public_key) # Build and send the HANDSHAKE response. try: @@ -685,7 +685,7 @@ async def _handle_whoareyou( packet = encode_packet( dest_node_id=remote_node_id, flag=PacketFlag.HANDSHAKE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=send_key, @@ -725,7 +725,7 @@ async def _handle_handshake( if len(message_bytes) > 0: plaintext = decrypt_message( encryption_key=result.session.recv_key, - nonce=bytes(header.nonce), + nonce=header.nonce, ciphertext=message_bytes, message_ad=message_ad, ) @@ -763,7 +763,7 @@ async def _handle_message( try: plaintext = decrypt_message( encryption_key=session.recv_key, - nonce=bytes(header.nonce), + nonce=header.nonce, ciphertext=message_bytes, message_ad=message_ad, ) @@ -828,11 +828,11 @@ async def _send_whoareyou( # # This IV is part of the challenge_data used for key derivation. # Both sides must use identical challenge_data to derive matching keys. - masking_iv = os.urandom(16) + masking_iv = Bytes16(os.urandom(16)) id_nonce, authdata, nonce, challenge_data = self._handshake_manager.create_whoareyou( remote_node_id=remote_node_id, - request_nonce=bytes(request_nonce), + request_nonce=request_nonce, remote_enr_seq=remote_enr_seq, masking_iv=masking_iv, ) @@ -844,7 +844,7 @@ async def _send_whoareyou( authdata=authdata, message=b"", encryption_key=None, - masking_iv=Bytes16(masking_iv), + masking_iv=masking_iv, ) self._transport.sendto(packet, addr) @@ -893,7 +893,7 @@ async def send_response( packet = encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=session.send_key, diff --git a/src/lean_spec/subspecs/networking/enr/enr.py b/src/lean_spec/subspecs/networking/enr/enr.py index 3afc188c..b0e6879c 100644 --- a/src/lean_spec/subspecs/networking/enr/enr.py +++ b/src/lean_spec/subspecs/networking/enr/enr.py @@ -50,16 +50,14 @@ encode_dss_signature, ) -from lean_spec.subspecs.networking.types import Multiaddr, NodeId, SeqNumber +from lean_spec.subspecs.networking.types import ForkDigest, Multiaddr, NodeId, SeqNumber, Version from lean_spec.types import ( - Bytes32, Bytes33, Bytes64, StrictBaseModel, Uint64, rlp, ) -from lean_spec.types.byte_arrays import Bytes4 from . import keys from .eth2 import AttestationSubnets, Eth2Data, SyncCommitteeSubnets @@ -159,8 +157,8 @@ def eth2_data(self) -> Eth2Data | None: eth2_bytes = self.get(keys.ETH2) if eth2_bytes and len(eth2_bytes) >= 16: return Eth2Data( - fork_digest=Bytes4(eth2_bytes[0:4]), - next_fork_version=Bytes4(eth2_bytes[4:8]), + fork_digest=ForkDigest(eth2_bytes[0:4]), + next_fork_version=Version(eth2_bytes[4:8]), next_fork_epoch=Uint64(int.from_bytes(eth2_bytes[8:16], "little")), ) return None @@ -211,7 +209,7 @@ def _build_content_items(self) -> list[bytes]: Returns [seq, k1, v1, k2, v2, ...] with keys sorted lexicographically. """ - sorted_keys = sorted(self.pairs.keys()) + sorted_keys = sorted(self.pairs) # Sequence number: minimal big-endian, empty bytes for zero. seq_bytes = self.seq.to_bytes(8, "big").lstrip(b"\x00") or b"" @@ -292,7 +290,7 @@ def compute_node_id(self) -> NodeId | None: # Hash the 64-byte x||y (excluding 0x04 prefix). k = keccak.new(digest_bits=256) k.update(uncompressed[1:]) - return Bytes32(k.digest()) + return NodeId(k.digest()) except (ValueError, TypeError): return None diff --git a/src/lean_spec/subspecs/networking/enr/eth2.py b/src/lean_spec/subspecs/networking/enr/eth2.py index 5df7ab12..da835d81 100644 --- a/src/lean_spec/subspecs/networking/enr/eth2.py +++ b/src/lean_spec/subspecs/networking/enr/eth2.py @@ -135,9 +135,9 @@ def is_subscribed(self, subnet_id: int) -> bool: raise ValueError(f"Sync subnet ID must be 0-3, got {subnet_id}") return bool(self.data[subnet_id]) - def subscribed_subnets(self) -> list[int]: + def subscribed_subnets(self) -> list[SubnetId]: """List of subscribed sync subnet IDs.""" - return [i for i in range(self.LENGTH) if self.data[i]] + return [SubnetId(i) for i in range(self.LENGTH) if self.data[i]] def subscription_count(self) -> int: """Number of subscribed sync subnets.""" diff --git a/src/lean_spec/subspecs/networking/enr/keys.py b/src/lean_spec/subspecs/networking/enr/keys.py index 6b83764f..5444a881 100644 --- a/src/lean_spec/subspecs/networking/enr/keys.py +++ b/src/lean_spec/subspecs/networking/enr/keys.py @@ -11,7 +11,7 @@ from typing import Final EnrKey = str -"""Type alias for ENR keys (can be any string/bytes per EIP-778)""" +"""ENR key identifier (any string/bytes per EIP-778).""" # EIP-778 Standard Keys ID: Final[EnrKey] = "id" diff --git a/src/lean_spec/subspecs/networking/gossipsub/behavior.py b/src/lean_spec/subspecs/networking/gossipsub/behavior.py index 27196d23..ef613220 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/behavior.py +++ b/src/lean_spec/subspecs/networking/gossipsub/behavior.py @@ -86,7 +86,7 @@ from lean_spec.subspecs.networking.transport import PeerId from lean_spec.subspecs.networking.transport.quic.stream_adapter import QuicStreamAdapter from lean_spec.subspecs.networking.varint import decode_varint, encode_varint -from lean_spec.types import Bytes20, Uint16 +from lean_spec.types import Uint16 logger = logging.getLogger(__name__) @@ -676,7 +676,7 @@ async def _handle_ihave(self, peer_id: PeerId, ihave: ControlIHave) -> None: # Message IDs must be exactly 20 bytes (SHA256 truncated to 160 bits). if len(msg_id) != 20: continue - msg_id_typed = Bytes20(msg_id) + msg_id_typed = MessageId(msg_id) if not self.seen_cache.has(msg_id_typed) and not self.message_cache.has(msg_id_typed): wanted.append(msg_id) @@ -695,7 +695,7 @@ async def _handle_iwant(self, peer_id: PeerId, iwant: ControlIWant) -> None: for msg_id in iwant.message_ids: if len(msg_id) != 20: continue - msg_id_typed = Bytes20(msg_id) + msg_id_typed = MessageId(msg_id) cached = self.message_cache.get(msg_id_typed) if cached: messages.append(Message(topic=cached.topic.decode("utf-8"), data=cached.raw_data)) @@ -718,7 +718,7 @@ def _handle_idontwant(self, peer_id: PeerId, idontwant: ControlIDontWant) -> Non for msg_id in idontwant.message_ids: if len(msg_id) != 20: continue - state.dont_want_ids.add(Bytes20(msg_id)) + state.dont_want_ids.add(MessageId(msg_id)) async def _heartbeat_loop(self) -> None: """Background heartbeat for mesh maintenance.""" diff --git a/src/lean_spec/subspecs/networking/gossipsub/message.py b/src/lean_spec/subspecs/networking/gossipsub/message.py index 952a3fa9..2037432d 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/message.py +++ b/src/lean_spec/subspecs/networking/gossipsub/message.py @@ -62,7 +62,6 @@ MESSAGE_DOMAIN_INVALID_SNAPPY, MESSAGE_DOMAIN_VALID_SNAPPY, ) -from lean_spec.types import Bytes20 from .types import MessageId @@ -172,7 +171,7 @@ def compute_id( preimage = bytes(domain) + len(topic).to_bytes(8, "little") + topic + data_for_hash - return Bytes20(hashlib.sha256(preimage).digest()[:20]) + return MessageId(hashlib.sha256(preimage).digest()[:20]) def __hash__(self) -> int: """Hash based on message ID. diff --git a/src/lean_spec/subspecs/networking/gossipsub/types.py b/src/lean_spec/subspecs/networking/gossipsub/types.py index 72361091..4c5aa0d5 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/types.py +++ b/src/lean_spec/subspecs/networking/gossipsub/types.py @@ -4,15 +4,16 @@ from lean_spec.types import Bytes20 -type MessageId = Bytes20 -"""20-byte message identifier. -Computed from message contents using SHA256:: +class MessageId(Bytes20): + """20-byte message identifier. - SHA256(domain + uint64_le(len(topic)) + topic + data)[:20] + Computed from message contents using SHA256:: -The domain byte distinguishes valid/invalid snappy compression. -""" + SHA256(domain + uint64_le(len(topic)) + topic + data)[:20] + + The domain byte distinguishes valid/invalid snappy compression. + """ type TopicId = str diff --git a/src/lean_spec/subspecs/networking/peer.py b/src/lean_spec/subspecs/networking/peer.py index ac309073..13e16d50 100644 --- a/src/lean_spec/subspecs/networking/peer.py +++ b/src/lean_spec/subspecs/networking/peer.py @@ -7,14 +7,14 @@ from typing import TYPE_CHECKING from .transport import PeerId -from .types import ConnectionState, Direction, Multiaddr +from .types import ConnectionState, Direction, ForkDigest, Multiaddr if TYPE_CHECKING: from .enr import ENR from .reqresp import Status -@dataclass +@dataclass(slots=True) class PeerInfo: """ Information about a known peer. @@ -59,7 +59,7 @@ def update_last_seen(self) -> None: self.last_seen = time() @property - def fork_digest(self) -> bytes | None: + def fork_digest(self) -> ForkDigest | None: """ Get the peer's fork_digest from cached ENR. @@ -71,4 +71,4 @@ def fork_digest(self) -> bytes | None: eth2_data = self.enr.eth2_data if eth2_data is None: return None - return bytes(eth2_data.fork_digest) + return eth2_data.fork_digest diff --git a/src/lean_spec/subspecs/networking/service/service.py b/src/lean_spec/subspecs/networking/service/service.py index 28fa9bbf..b06e2699 100644 --- a/src/lean_spec/subspecs/networking/service/service.py +++ b/src/lean_spec/subspecs/networking/service/service.py @@ -232,7 +232,7 @@ async def publish_attestation( compressed = frame_compress(ssz_bytes) await self.event_source.publish(str(topic), compressed) - logger.debug("Published attestation for slot %s", attestation.message.slot) + logger.debug("Published attestation for slot %s", attestation.data.slot) async def publish_aggregated_attestation( self, signed_attestation: SignedAggregatedAttestation diff --git a/src/lean_spec/subspecs/networking/transport/identity/keypair.py b/src/lean_spec/subspecs/networking/transport/identity/keypair.py index 7b23c3b1..4b09b838 100644 --- a/src/lean_spec/subspecs/networking/transport/identity/keypair.py +++ b/src/lean_spec/subspecs/networking/transport/identity/keypair.py @@ -16,6 +16,8 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec +from lean_spec.types import Bytes32, Bytes33 + from ..peer_id import KeyType, PeerId, PublicKeyProto __all__ = [ @@ -28,7 +30,7 @@ class IdentityKeypair: """ secp256k1 keypair for libp2p identity. - Used to derive PeerId and sign identity proofs during Noise handshake. + Used to derive PeerId and sign identity proofs during QUIC TLS handshake. Attributes: private_key: The secp256k1 private key. @@ -48,7 +50,7 @@ def generate(cls) -> IdentityKeypair: return cls(private_key=private_key) @classmethod - def from_bytes(cls, data: bytes) -> IdentityKeypair: + def from_bytes(cls, data: Bytes32) -> IdentityKeypair: """ Load keypair from raw private key bytes. @@ -57,20 +59,14 @@ def from_bytes(cls, data: bytes) -> IdentityKeypair: Returns: Identity keypair. - - Raises: - ValueError: If data is not a valid secp256k1 private key. """ - if len(data) != 32: - raise ValueError(f"Expected 32 bytes, got {len(data)}") - private_key = ec.derive_private_key( int.from_bytes(data, "big"), ec.SECP256K1(), ) return cls(private_key=private_key) - def private_key_bytes(self) -> bytes: + def private_key_bytes(self) -> Bytes32: """ Return the raw 32-byte private key. @@ -78,9 +74,9 @@ def private_key_bytes(self) -> bytes: 32-byte private key scalar. """ private_numbers = self.private_key.private_numbers() - return private_numbers.private_value.to_bytes(32, "big") + return Bytes32(private_numbers.private_value.to_bytes(32, "big")) - def public_key_bytes(self) -> bytes: + def public_key_bytes(self) -> Bytes33: """ Return the compressed secp256k1 public key (33 bytes). @@ -91,9 +87,11 @@ def public_key_bytes(self) -> bytes: 33-byte compressed public key. """ public_key = self.private_key.public_key() - return public_key.public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + return Bytes33( + public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) def sign(self, message: bytes) -> bytes: @@ -128,7 +126,7 @@ def to_peer_id(self) -> PeerId: def verify_signature( - public_key_bytes: bytes, + public_key_bytes: Bytes33, message: bytes, signature: bytes, ) -> bool: diff --git a/src/lean_spec/subspecs/networking/transport/identity/signature.py b/src/lean_spec/subspecs/networking/transport/identity/signature.py index 59fe2eb9..5c98c501 100644 --- a/src/lean_spec/subspecs/networking/transport/identity/signature.py +++ b/src/lean_spec/subspecs/networking/transport/identity/signature.py @@ -1,11 +1,11 @@ """ -Identity proof for libp2p-noise handshake. +Identity proof for libp2p peer authentication. -During the Noise handshake, peers must prove they own their claimed -libp2p identity key by signing the Noise static public key. +During the QUIC TLS handshake, peers must prove they own their claimed +libp2p identity key by signing the TLS public key. -The signature format follows the libp2p-noise specification: - message = "noise-libp2p-static-key:" || noise_public_key +The signature format follows the libp2p specification: + message = "noise-libp2p-static-key:" || public_key signature = ECDSA-SHA256(identity_private_key, message) References: @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Final +from lean_spec.types import Bytes32, Bytes33 + from .keypair import verify_signature if TYPE_CHECKING: @@ -34,43 +36,42 @@ def create_identity_proof( identity_key: IdentityKeypair, - noise_public_key: bytes, + public_key: Bytes32, ) -> bytes: """ - Create identity proof signature for Noise handshake. + Create identity proof signature for peer authentication. Proves that the owner of the identity key (secp256k1) also controls - the Noise static key (X25519). This binding prevents man-in-the-middle - attacks where an attacker substitutes their Noise key. + the TLS static key. This binding prevents man-in-the-middle attacks. Args: identity_key: The secp256k1 identity keypair. - noise_public_key: The 32-byte X25519 Noise static public key. + public_key: The 32-byte TLS public key. Returns: DER-encoded ECDSA signature. """ - message = NOISE_IDENTITY_PREFIX + noise_public_key + message = NOISE_IDENTITY_PREFIX + public_key return identity_key.sign(message) def verify_identity_proof( - identity_public_key: bytes, - noise_public_key: bytes, + identity_public_key: Bytes33, + public_key: Bytes32, signature: bytes, ) -> bool: """ Verify identity proof signature. - Called during Noise handshake to verify the remote peer's identity claim. + Called during QUIC TLS handshake to verify the remote peer's identity claim. Args: identity_public_key: 33-byte compressed secp256k1 public key. - noise_public_key: 32-byte X25519 Noise static public key. + public_key: 32-byte TLS public key. signature: DER-encoded ECDSA signature. Returns: True if the signature is valid, False otherwise. """ - message = NOISE_IDENTITY_PREFIX + noise_public_key + message = NOISE_IDENTITY_PREFIX + public_key return verify_signature(identity_public_key, message, signature) diff --git a/src/lean_spec/subspecs/networking/transport/peer_id.py b/src/lean_spec/subspecs/networking/transport/peer_id.py index 05e6e277..2a1f4848 100644 --- a/src/lean_spec/subspecs/networking/transport/peer_id.py +++ b/src/lean_spec/subspecs/networking/transport/peer_id.py @@ -38,6 +38,7 @@ from typing import Final from lean_spec.subspecs.networking import varint +from lean_spec.types import Bytes33 __all__ = [ # Main types @@ -408,7 +409,7 @@ def from_public_key(cls, public_key: PublicKeyProto) -> PeerId: return cls(multihash=mh.encode()) @classmethod - def from_secp256k1(cls, public_key_bytes: bytes) -> PeerId: + def from_secp256k1(cls, public_key_bytes: Bytes33) -> PeerId: """ Derive PeerId from a secp256k1 compressed public key. @@ -421,14 +422,6 @@ def from_secp256k1(cls, public_key_bytes: bytes) -> PeerId: Returns: Derived PeerId (starts with "16Uiu2..." for secp256k1). - - Raises: - ValueError: If public key is not 33 bytes. """ - if len(public_key_bytes) != 33: - raise ValueError( - f"secp256k1 compressed key must be 33 bytes, got {len(public_key_bytes)}" - ) - proto = PublicKeyProto(key_type=KeyType.SECP256K1, key_data=public_key_bytes) return cls.from_public_key(proto) diff --git a/src/lean_spec/subspecs/networking/types.py b/src/lean_spec/subspecs/networking/types.py index a7fb86b3..208b5367 100644 --- a/src/lean_spec/subspecs/networking/types.py +++ b/src/lean_spec/subspecs/networking/types.py @@ -5,23 +5,27 @@ from lean_spec.types import Uint64 from lean_spec.types.byte_arrays import Bytes1, Bytes4, Bytes32 -DomainType = Bytes1 -"""1-byte domain for message-id isolation in Gossipsub. -The domain is a single byte prepended to the message hash to compute the gossip message ID. +class DomainType(Bytes1): + """1-byte domain for message-id isolation in Gossipsub. -- Valid messages use 0x01, -- Invalid messages use 0x00. -""" + The domain is a single byte prepended to the message hash to compute the gossip message ID. -NodeId = Bytes32 -"""32-byte node identifier for Discovery v5, derived from ``keccak256(pubkey)``.""" + - Valid messages use 0x01, + - Invalid messages use 0x00. + """ + + +class NodeId(Bytes32): + """32-byte node identifier for Discovery v5, derived from ``keccak256(pubkey)``.""" + + +class ForkDigest(Bytes4): + """4-byte fork identifier ensuring network isolation between forks.""" -ForkDigest = Bytes4 -"""4-byte fork identifier ensuring network isolation between forks.""" -Version = Bytes4 -"""4-byte fork version number (e.g., 0x01000000 for Phase0).""" +class Version(Bytes4): + """4-byte fork version number (e.g., 0x01000000 for Phase0).""" class SeqNumber(Uint64): diff --git a/src/lean_spec/subspecs/node/node.py b/src/lean_spec/subspecs/node/node.py index 3372b02c..7466fb6e 100644 --- a/src/lean_spec/subspecs/node/node.py +++ b/src/lean_spec/subspecs/node/node.py @@ -205,7 +205,7 @@ def from_genesis(cls, config: NodeConfig) -> Node: # Initialize forkchoice store. # # Genesis block is both justified and finalized. - store = Store.get_forkchoice_store(state, block, validator_id) + store = state.to_forkchoice_store(block, validator_id) # Persist genesis to database if available. if database is not None: diff --git a/src/lean_spec/subspecs/ssz/hash.py b/src/lean_spec/subspecs/ssz/hash.py index b4baaac8..d7b6b292 100644 --- a/src/lean_spec/subspecs/ssz/hash.py +++ b/src/lean_spec/subspecs/ssz/hash.py @@ -149,9 +149,7 @@ def _htr_list(value: SSZList) -> Bytes32: @hash_tree_root.register def _htr_container(value: Container) -> Bytes32: # Preserve declared field order from the Pydantic model - return merkleize( - [hash_tree_root(getattr(value, fname)) for fname in type(value).model_fields.keys()] - ) + return merkleize([hash_tree_root(getattr(value, fname)) for fname in type(value).model_fields]) @hash_tree_root.register diff --git a/src/lean_spec/subspecs/sync/block_cache.py b/src/lean_spec/subspecs/sync/block_cache.py index 8ce51871..5728bcad 100644 --- a/src/lean_spec/subspecs/sync/block_cache.py +++ b/src/lean_spec/subspecs/sync/block_cache.py @@ -345,7 +345,7 @@ def get_children(self, parent_root: Bytes32) -> list[PendingBlock]: # parent P, we must process A before B. return sorted(children, key=lambda p: p.slot) - def get_processable(self, store: "Store") -> list[PendingBlock]: + def get_processable(self, store: Store) -> list[PendingBlock]: """ Get blocks whose parents exist in the Store. diff --git a/src/lean_spec/subspecs/sync/checkpoint_sync.py b/src/lean_spec/subspecs/sync/checkpoint_sync.py index 8e418804..18fa9b26 100644 --- a/src/lean_spec/subspecs/sync/checkpoint_sync.py +++ b/src/lean_spec/subspecs/sync/checkpoint_sync.py @@ -44,7 +44,7 @@ class CheckpointSyncError(Exception): """ -async def fetch_finalized_state(url: str, state_class: type["State"]) -> "State": +async def fetch_finalized_state(url: str, state_class: type[State]) -> State: """ Fetch finalized state from a node via checkpoint sync. @@ -102,7 +102,7 @@ async def fetch_finalized_state(url: str, state_class: type["State"]) -> "State" raise CheckpointSyncError(f"Failed to fetch state: {e}") from e -def verify_checkpoint_state(state: "State") -> bool: +def verify_checkpoint_state(state: State) -> bool: """ Verify that a checkpoint state is structurally valid. diff --git a/src/lean_spec/subspecs/sync/head_sync.py b/src/lean_spec/subspecs/sync/head_sync.py index a04cf282..198564ee 100644 --- a/src/lean_spec/subspecs/sync/head_sync.py +++ b/src/lean_spec/subspecs/sync/head_sync.py @@ -45,8 +45,8 @@ from __future__ import annotations import logging +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Callable from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.forkchoice import Store diff --git a/src/lean_spec/subspecs/sync/states.py b/src/lean_spec/subspecs/sync/states.py index 20fb1317..2e672a92 100644 --- a/src/lean_spec/subspecs/sync/states.py +++ b/src/lean_spec/subspecs/sync/states.py @@ -98,7 +98,7 @@ class SyncState(Enum): - Falls back to SYNCING if gaps appear """ - def can_transition_to(self, target: "SyncState") -> bool: + def can_transition_to(self, target: SyncState) -> bool: """ Check if transition to target state is valid. diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index dfa1a685..6333e36b 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -492,7 +492,7 @@ def _sign_attestation( return SignedAttestation( validator_id=validator_index, - message=attestation_data, + data=attestation_data, signature=signature, ) diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index c9ad8060..e7a5a795 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -79,9 +79,9 @@ def _serialize_as_bytes(self) -> str: def verify( self, public_key: PublicKey, - slot: "Slot", - message: "Bytes32", - scheme: "GeneralizedXmssScheme", + slot: Slot, + message: Bytes32, + scheme: GeneralizedXmssScheme, ) -> bool: """ Verify the signature using XMSS verification algorithm. diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 2c1bbe25..831948b8 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -62,7 +62,7 @@ class GeneralizedXmssScheme(StrictBaseModel): """Random data generator for key generation.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "GeneralizedXmssScheme": + def _validate_strict_types(self) -> GeneralizedXmssScheme: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types( self, diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index e6c823f2..13deb074 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -65,7 +65,7 @@ class MessageHasher(StrictBaseModel): """Poseidon hash engine.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "MessageHasher": + def _validate_strict_types(self) -> MessageHasher: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 92f72cc6..312798b1 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -47,7 +47,7 @@ class PoseidonXmss(StrictBaseModel): """Poseidon2 parameters for 24-width permutation.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "PoseidonXmss": + def _validate_strict_types(self) -> PoseidonXmss: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, params16=Poseidon2Params, params24=Poseidon2Params) return self diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 8dd455a9..8b9a8c64 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -110,7 +110,7 @@ class Prf(StrictBaseModel): """Configuration parameters for the PRF.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "Prf": + def _validate_strict_types(self) -> Prf: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig) return self diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index c1f1f6a2..3b3a7aa5 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -325,14 +325,14 @@ def new_bottom_tree( @classmethod def from_prf_key( cls, - prf: "Prf", - hasher: "TweakHasher", - rand: "Rand", - config: "XmssConfig", + prf: Prf, + hasher: TweakHasher, + rand: Rand, + config: XmssConfig, prf_key: PRFKey, bottom_tree_index: Uint64, parameter: Parameter, - ) -> "HashSubTree": + ) -> HashSubTree: """ Generates a single bottom tree on-demand from the PRF key. @@ -544,7 +544,7 @@ def combined_path( def verify_path( - hasher: "TweakHasher", + hasher: TweakHasher, parameter: Parameter, root: HashDigestVector, position: Uint64, diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index ec8b0538..3657f051 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -87,7 +87,7 @@ class TweakHasher(StrictBaseModel): """Poseidon permutation instance for hashing.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "TweakHasher": + def _validate_strict_types(self) -> TweakHasher: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self diff --git a/tests/consensus/devnet/ssz/test_consensus_containers.py b/tests/consensus/devnet/ssz/test_consensus_containers.py index 9a1ba7aa..5f390748 100644 --- a/tests/consensus/devnet/ssz/test_consensus_containers.py +++ b/tests/consensus/devnet/ssz/test_consensus_containers.py @@ -127,7 +127,7 @@ def test_signed_attestation_minimal(ssz: SSZTestFiller) -> None: type_name="SignedAttestation", value=SignedAttestation( validator_id=ValidatorIndex(0), - message=_zero_attestation_data(), + data=_zero_attestation_data(), signature=_empty_signature(), ), ) diff --git a/tests/lean_spec/conftest.py b/tests/lean_spec/conftest.py index 20249991..d544f6db 100644 --- a/tests/lean_spec/conftest.py +++ b/tests/lean_spec/conftest.py @@ -71,8 +71,7 @@ def genesis_block(genesis_state: State) -> Block: @pytest.fixture def base_store(genesis_state: State, genesis_block: Block) -> Store: """Fork choice store initialized with genesis.""" - return Store.get_forkchoice_store( - genesis_state, + return genesis_state.to_forkchoice_store( genesis_block, validator_id=ValidatorIndex(0), ) @@ -108,4 +107,4 @@ def keyed_store(keyed_genesis: GenesisData) -> Store: @pytest.fixture def observer_store(keyed_genesis_state: State, keyed_genesis_block: Block) -> Store: """Fork choice store with validator_id=None (non-validator observer).""" - return Store.get_forkchoice_store(keyed_genesis_state, keyed_genesis_block, validator_id=None) + return keyed_genesis_state.to_forkchoice_store(keyed_genesis_block, validator_id=None) diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index c07becac..edc2deb2 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -30,7 +30,7 @@ from lean_spec.subspecs.containers.block.types import AggregatedAttestations, AttestationSignatures from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.koalabear import Fp from lean_spec.subspecs.networking import PeerId @@ -240,7 +240,9 @@ def make_aggregated_attestation( ) return AggregatedAttestation( - aggregation_bits=AggregationBits.from_validator_indices(participant_ids), + aggregation_bits=AggregationBits.from_validator_indices( + ValidatorIndices(data=participant_ids) + ), data=data, ) @@ -264,7 +266,7 @@ def make_signed_attestation( ) return SignedAttestation( validator_id=validator, - message=attestation_data, + data=attestation_data, signature=make_mock_signature(), ) @@ -326,7 +328,7 @@ def make_genesis_data( validators = make_validators(num_validators) genesis_state = make_genesis_state(validators=validators, genesis_time=genesis_time) genesis_block = make_genesis_block(genesis_state) - store = Store.get_forkchoice_store(genesis_state, genesis_block, validator_id=validator_id) + store = genesis_state.to_forkchoice_store(genesis_block, validator_id=validator_id) return GenesisData(store, genesis_state, genesis_block) @@ -423,7 +425,7 @@ def make_aggregated_proof( """Create a valid aggregated signature proof for the given participants.""" data_root = attestation_data.data_root_bytes() return AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices(ValidatorIndices(data=participants)), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants diff --git a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py index 83dfb8bd..d50215dc 100644 --- a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py +++ b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py @@ -61,7 +61,9 @@ def test_aggregated_attestation_structure(self) -> None: source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), ) - bits = AggregationBits.from_validator_indices([ValidatorIndex(2), ValidatorIndex(7)]) + bits = AggregationBits.from_validator_indices( + ValidatorIndices(data=[ValidatorIndex(2), ValidatorIndex(7)]) + ) agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) # Verify we can extract validator indices @@ -78,12 +80,12 @@ def test_aggregated_attestation_with_many_validators(self) -> None: source=Checkpoint(root=Bytes32.zero(), slot=Slot(7)), ) - validator_ids = [ValidatorIndex(i) for i in [0, 5, 10, 15, 20, 25]] + validator_ids = ValidatorIndices(data=[ValidatorIndex(i) for i in [0, 5, 10, 15, 20, 25]]) bits = AggregationBits.from_validator_indices(validator_ids) agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) recovered = agg.aggregation_bits.to_validator_indices() - assert recovered == ValidatorIndices(data=validator_ids) + assert recovered == validator_ids class TestAggregateByData: diff --git a/tests/lean_spec/subspecs/containers/test_state_process_attestations.py b/tests/lean_spec/subspecs/containers/test_state_process_attestations.py index 33f9bca4..b03aa08e 100644 --- a/tests/lean_spec/subspecs/containers/test_state_process_attestations.py +++ b/tests/lean_spec/subspecs/containers/test_state_process_attestations.py @@ -50,7 +50,7 @@ HistoricalBlockHashes, JustifiedSlots, ) -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.types import Boolean from tests.lean_spec.helpers import make_bytes32, make_genesis_state @@ -127,7 +127,7 @@ def test_attestation_with_target_beyond_history_is_silently_rejected(self) -> No attestation = AggregatedAttestation( # Two validators participate in this attestation. aggregation_bits=AggregationBits.from_validator_indices( - [ValidatorIndex(0), ValidatorIndex(1)] + ValidatorIndices(data=[ValidatorIndex(0), ValidatorIndex(1)]) ), data=att_data, ) @@ -209,7 +209,7 @@ def test_attestation_with_source_beyond_history_is_silently_rejected(self) -> No attestation = AggregatedAttestation( aggregation_bits=AggregationBits.from_validator_indices( - [ValidatorIndex(0), ValidatorIndex(1)] + ValidatorIndices(data=[ValidatorIndex(0), ValidatorIndex(1)]) ), data=att_data, ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py index 4a06e3af..4283d038 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py +++ b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py @@ -157,7 +157,7 @@ def test_safe_target_requires_supermajority( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) # Process as gossip (requires aggregator flag) @@ -202,7 +202,7 @@ def test_safe_target_advances_with_supermajority( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -242,7 +242,7 @@ def test_update_safe_target_uses_new_attestations( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -296,7 +296,7 @@ def test_justification_with_supermajority_attestations( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -378,7 +378,7 @@ def test_justification_tracking_with_multiple_targets( sig = key_manager.sign_attestation_data(vid, attestation_data_head) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data_head, + data=attestation_data_head, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -427,7 +427,7 @@ def test_finalization_after_consecutive_justification( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -521,7 +521,7 @@ def test_full_attestation_cycle( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) # Process as gossip diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py index e2cd1e28..ff3c1ebb 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py +++ b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py @@ -16,7 +16,7 @@ ) from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof, SignatureKey from lean_spec.types import Bytes32, Uint64 from tests.lean_spec.helpers import ( @@ -173,7 +173,7 @@ def test_same_subnet_stores_signature(self, key_manager: XmssKeyManager) -> None signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -216,7 +216,7 @@ def test_cross_subnet_ignores_signature(self, key_manager: XmssKeyManager) -> No signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -251,7 +251,7 @@ def test_non_aggregator_never_stores_signature(self, key_manager: XmssKeyManager signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -286,7 +286,7 @@ def test_attestation_data_always_stored(self, key_manager: XmssKeyManager) -> No signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -336,7 +336,9 @@ def test_valid_proof_stored_correctly(self, key_manager: XmssKeyManager) -> None # Create valid aggregated proof proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants @@ -377,7 +379,9 @@ def test_attestation_data_stored_by_root(self, key_manager: XmssKeyManager) -> N data_root = attestation_data.data_root_bytes() proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants @@ -413,7 +417,9 @@ def test_invalid_proof_rejected(self, key_manager: XmssKeyManager) -> None: # Create proof with WRONG signers (validator 3 signs instead of 2) proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(claimed_participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=claimed_participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in actual_signers], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in actual_signers @@ -446,7 +452,9 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: # First proof: validators 1 and 2 participants_1 = [ValidatorIndex(1), ValidatorIndex(2)] proof_1 = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants_1), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants_1) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants_1], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_1 @@ -458,7 +466,9 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: # Second proof: validators 1 and 3 (validator 1 overlaps) participants_2 = [ValidatorIndex(1), ValidatorIndex(3)] proof_2 = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants_2), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants_2) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants_2], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_2 @@ -797,7 +807,7 @@ def test_gossip_to_aggregation_to_storage(self, key_manager: XmssKeyManager) -> for vid in attesting_validators: signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(vid, attestation_data), ) store = store.on_gossip_attestation( diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py b/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py index dd36419d..4a5cacfb 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py +++ b/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py @@ -2,7 +2,7 @@ from lean_spec.subspecs.containers.attestation import AggregationBits from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof, SignatureKey from lean_spec.types import Bytes32 @@ -149,7 +149,9 @@ def test_prunes_related_structures_together(pruning_store: Store) -> None: # Create mock aggregated proof (empty proof data for testing) mock_proof = AggregatedSignatureProof( - participants=AggregationBits.from_validator_indices([ValidatorIndex(1)]), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=[ValidatorIndex(1)]) + ), proof_data=ByteListMiB(data=b""), ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_time_management.py b/tests/lean_spec/subspecs/forkchoice/test_time_management.py index 2adbd852..4811c0f1 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_time_management.py +++ b/tests/lean_spec/subspecs/forkchoice/test_time_management.py @@ -21,14 +21,14 @@ class TestGetForkchoiceStore: - """Test Store.get_forkchoice_store() time initialization.""" + """Test State.to_forkchoice_store() time initialization.""" @settings(max_examples=100) @given(anchor_slot=st.integers(min_value=0, max_value=10000)) def test_store_time_from_anchor_slot(self, anchor_slot: int) -> None: - """get_forkchoice_store sets time = anchor_slot * INTERVALS_PER_SLOT.""" + """to_forkchoice_store sets time = anchor_slot * INTERVALS_PER_SLOT.""" # Must create its own state and block instead of using sample_store() - # because sample_store() bypasses get_forkchoice_store() with hardcoded time. + # because sample_store() bypasses to_forkchoice_store() with hardcoded time. state = State.generate_genesis( genesis_time=Uint64(1000), validators=Validators(data=[]), @@ -43,9 +43,8 @@ def test_store_time_from_anchor_slot(self, anchor_slot: int) -> None: body=make_empty_block_body(), ) - store = Store.get_forkchoice_store( - anchor_state=state, - anchor_block=anchor_block, + store = state.to_forkchoice_store( + anchor_block, validator_id=TEST_VALIDATOR_ID, ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_validator.py b/tests/lean_spec/subspecs/forkchoice/test_validator.py index 57824518..3396eb9a 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_validator.py +++ b/tests/lean_spec/subspecs/forkchoice/test_validator.py @@ -67,7 +67,7 @@ def test_produce_block_with_attestations( ) signed_5 = SignedAttestation( validator_id=ValidatorIndex(5), - message=data_5, + data=data_5, signature=key_manager.sign_attestation_data(ValidatorIndex(5), data_5), ) data_6 = AttestationData( @@ -78,15 +78,15 @@ def test_produce_block_with_attestations( ) signed_6 = SignedAttestation( validator_id=ValidatorIndex(6), - message=data_6, + data=data_6, signature=key_manager.sign_attestation_data(ValidatorIndex(6), data_6), ) - data_root_5 = signed_5.message.data_root_bytes() - data_root_6 = signed_6.message.data_root_bytes() + data_root_5 = signed_5.data.data_root_bytes() + data_root_6 = signed_6.data.data_root_bytes() - proof_5 = make_aggregated_proof(key_manager, [ValidatorIndex(5)], signed_5.message) - proof_6 = make_aggregated_proof(key_manager, [ValidatorIndex(6)], signed_6.message) + proof_5 = make_aggregated_proof(key_manager, [ValidatorIndex(5)], signed_5.data) + proof_6 = make_aggregated_proof(key_manager, [ValidatorIndex(6)], signed_6.data) sig_key_5 = SignatureKey(ValidatorIndex(5), data_root_5) sig_key_6 = SignatureKey(ValidatorIndex(6), data_root_6) @@ -98,8 +98,8 @@ def test_produce_block_with_attestations( sig_key_6: [proof_6], }, "attestation_data_by_root": { - data_root_5: signed_5.message, - data_root_6: signed_6.message, + data_root_5: signed_5.data, + data_root_6: signed_6.data, }, "gossip_signatures": { sig_key_5: signed_5.signature, @@ -216,18 +216,18 @@ def test_produce_block_state_consistency( ) signed_7 = SignedAttestation( validator_id=ValidatorIndex(7), - message=data_7, + data=data_7, signature=key_manager.sign_attestation_data(ValidatorIndex(7), data_7), ) - data_root_7 = signed_7.message.data_root_bytes() - proof_7 = make_aggregated_proof(key_manager, [ValidatorIndex(7)], signed_7.message) + data_root_7 = signed_7.data.data_root_bytes() + proof_7 = make_aggregated_proof(key_manager, [ValidatorIndex(7)], signed_7.data) sig_key_7 = SignatureKey(ValidatorIndex(7), data_root_7) sample_store = sample_store.model_copy( update={ "latest_known_aggregated_payloads": {sig_key_7: [proof_7]}, - "attestation_data_by_root": {data_root_7: signed_7.message}, + "attestation_data_by_root": {data_root_7: signed_7.data}, "gossip_signatures": {sig_key_7: signed_7.signature}, } ) diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index 84aa2325..03268196 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -4,27 +4,34 @@ import pytest +from lean_spec.subspecs.networking.discovery.messages import IdNonce from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes64 +from lean_spec.types import Bytes32, Bytes33, Bytes64 # From devp2p test vectors -NODE_A_PRIVKEY = bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") +NODE_A_PRIVKEY = Bytes32( + bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") +) NODE_A_ID = NodeId( bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") ) -NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") +NODE_B_PRIVKEY = Bytes32( + bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") +) NODE_B_ID = NodeId( bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") ) -NODE_B_PUBKEY = bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +NODE_B_PUBKEY = Bytes33( + bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +) # Spec id-nonce used in WHOAREYOU test vectors. -SPEC_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") +SPEC_ID_NONCE = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) @pytest.fixture -def local_private_key() -> bytes: +def local_private_key() -> Bytes32: """Node B's private key from devp2p test vectors.""" return NODE_B_PRIVKEY diff --git a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py index 413c9a47..e7f5cb6b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py @@ -56,16 +56,6 @@ def test_different_ivs_produce_different_ciphertext(self): assert ct1 != ct2 - def test_invalid_key_length_raises(self): - """Test that invalid key length raises ValueError.""" - with pytest.raises(ValueError, match="Key must be 16 bytes"): - aes_ctr_encrypt(bytes(15), bytes(16), b"data") # type: ignore[arg-type] - - def test_invalid_iv_length_raises(self): - """Test that invalid IV length raises ValueError.""" - with pytest.raises(ValueError, match="IV must be 16 bytes"): - aes_ctr_encrypt(bytes(16), bytes(15), b"data") # type: ignore[arg-type] - class TestAesGcm: """Tests for AES-GCM encryption/decryption.""" @@ -105,16 +95,6 @@ def test_wrong_aad_fails_decryption(self): with pytest.raises(InvalidTag): aes_gcm_decrypt(key, nonce, ciphertext, b"wrong aad") - def test_invalid_key_length_raises(self): - """Test that invalid key length raises ValueError.""" - with pytest.raises(ValueError, match="Key must be 16 bytes"): - aes_gcm_encrypt(bytes(15), bytes(12), b"data", b"") # type: ignore[arg-type] - - def test_invalid_nonce_length_raises(self): - """Test that invalid nonce length raises ValueError.""" - with pytest.raises(ValueError, match="Nonce must be 12 bytes"): - aes_gcm_encrypt(bytes(16), bytes(11), b"data", b"") # type: ignore[arg-type] - class TestEcdh: """Tests for secp256k1 ECDH key agreement.""" @@ -292,18 +272,6 @@ def test_zero_private_key_rejected(self): Bytes32.zero(), ) - def test_wrong_length_dest_node_id(self): - """Signing rejects non-32-byte destination node ID.""" - priv, _ = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - with pytest.raises((ValueError, TypeError)): - sign_id_nonce( - priv, - make_challenge_data(), - eph_pub, - bytes(16), # type: ignore[arg-type] - ) - class TestVerifyIdNonceNegativeCases: """Negative tests for ID nonce signature verification.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index 63c19809..e63d1fc0 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -16,7 +16,7 @@ PendingHandshake, ) from lean_spec.subspecs.networking.discovery.keys import compute_node_id -from lean_spec.subspecs.networking.discovery.messages import IdNonce +from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce from lean_spec.subspecs.networking.discovery.packet import ( HandshakeAuthdata, WhoAreYouAuthdata, @@ -27,7 +27,7 @@ from lean_spec.subspecs.networking.discovery.session import SessionCache from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Bytes64 +from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes64 from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY @@ -143,9 +143,9 @@ def test_cancel_nonexistent_returns_false(self, manager): def test_create_whoareyou(self, manager): """Test creating a WHOAREYOU challenge.""" remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(0) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id, request_nonce, remote_enr_seq, masking_iv @@ -189,28 +189,6 @@ def test_cleanup_expired(self, manager): assert manager.get_pending(remote1) is None assert manager.get_pending(remote2) is None - def test_invalid_local_node_id_raises(self): - """Test that invalid local node ID raises ValueError.""" - with pytest.raises(ValueError, match="Local node ID must be 32 bytes"): - HandshakeManager( - local_node_id=bytes(31), # type: ignore[arg-type] - local_private_key=bytes(32), - local_enr_rlp=b"enr", - local_enr_seq=SeqNumber(1), - session_cache=SessionCache(), - ) - - def test_invalid_local_private_key_raises(self): - """Test that invalid local private key raises ValueError.""" - with pytest.raises(ValueError, match="Local private key must be 32 bytes"): - HandshakeManager( - local_node_id=NodeId(bytes(32)), - local_private_key=bytes(31), - local_enr_rlp=b"enr", - local_enr_seq=SeqNumber(1), - session_cache=SessionCache(), - ) - class TestHandshakeState: """Tests for HandshakeState enum.""" @@ -270,9 +248,9 @@ def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): When we receive an undecryptable MESSAGE, we respond with WHOAREYOU. """ remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(0) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id, request_nonce, remote_enr_seq, masking_iv @@ -288,9 +266,9 @@ def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): def test_sent_whoareyou_state_has_challenge_data(self, manager): """In SENT_WHOAREYOU state, all challenge data is stored.""" remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(5) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) manager.create_whoareyou(remote_node_id, request_nonce, remote_enr_seq, masking_iv) @@ -330,8 +308,8 @@ def test_handle_handshake_requires_pending_state(self, manager, remote_keypair): src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -350,8 +328,8 @@ def test_handle_handshake_requires_sent_whoareyou_state(self, manager, remote_ke src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -366,9 +344,9 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) # Set up WHOAREYOU state. manager.create_whoareyou( NodeId(remote_node_id), - bytes(12), + Nonce(bytes(12)), SeqNumber(0), - bytes(16), + Bytes16(bytes(16)), ) # Create authdata with different src_id. @@ -377,8 +355,8 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) src_id=wrong_src_id, sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -397,9 +375,9 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa # Set up WHOAREYOU with enr_seq=0 (unknown). manager.create_whoareyou( NodeId(remote_node_id), - bytes(12), + Nonce(bytes(12)), SeqNumber(0), # enr_seq = 0 means we don't know remote's ENR - bytes(16), + Bytes16(bytes(16)), ) # Create authdata without ENR record. @@ -407,8 +385,8 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, # No ENR included. ) @@ -426,9 +404,9 @@ def test_successful_handshake_with_signature_verification( remote_priv, remote_pub, remote_node_id = remote_keypair # Node A (manager) creates WHOAREYOU for remote. - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv + NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv ) # Remote creates handshake response. @@ -475,8 +453,8 @@ def test_handle_handshake_rejects_invalid_signature( remote_priv, remote_pub, remote_node_id = remote_keypair # Set up WHOAREYOU state. - masking_iv = bytes(16) - manager.create_whoareyou(NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv) + masking_iv = Bytes16(bytes(16)) + manager.create_whoareyou(NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv) # Generate ephemeral key. _eph_priv, eph_pub = generate_secp256k1_keypair() @@ -490,7 +468,7 @@ def test_handle_handshake_rejects_invalid_signature( authdata_bytes = encode_handshake_authdata( src_id=NodeId(remote_node_id), - id_signature=bytes(64), # Wrong signature. + id_signature=Bytes64(bytes(64)), # Wrong signature. eph_pubkey=eph_pub, record=remote_enr.to_rlp(), ) @@ -515,7 +493,7 @@ def test_multiple_handshakes_independent(self, manager): manager.start_handshake(remote2) # Create WHOAREYOU for third remote. - manager.create_whoareyou(remote3, bytes(12), SeqNumber(0), bytes(16)) + manager.create_whoareyou(remote3, Nonce(bytes(12)), SeqNumber(0), Bytes16(bytes(16))) # All should have independent state. assert manager.get_pending(remote1).state == HandshakeState.SENT_ORDINARY @@ -581,8 +559,10 @@ def test_id_nonce_uniqueness_across_challenges(self, manager): remote1 = bytes.fromhex("01" + "00" * 31) remote2 = bytes.fromhex("02" + "00" * 31) - id_nonce1, _, _, _ = manager.create_whoareyou(remote1, bytes(12), SeqNumber(0), bytes(16)) - id_nonce2, _, _, _ = manager.create_whoareyou(remote2, bytes(12), SeqNumber(0), bytes(16)) + nonce = Nonce(bytes(12)) + iv = Bytes16(bytes(16)) + id_nonce1, _, _, _ = manager.create_whoareyou(remote1, nonce, SeqNumber(0), iv) + id_nonce2, _, _, _ = manager.create_whoareyou(remote2, nonce, SeqNumber(0), iv) # Each challenge should have unique id_nonce. assert id_nonce1 != id_nonce2 @@ -619,7 +599,7 @@ def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypa authdata, _, _ = manager.create_handshake_response( remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, - remote_pubkey=bytes(remote_pub), + remote_pubkey=remote_pub, challenge_data=challenge_data, ) @@ -655,7 +635,7 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key authdata, _, _ = manager.create_handshake_response( remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, - remote_pubkey=bytes(remote_pub), + remote_pubkey=remote_pub, challenge_data=challenge_data, ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index bff0d39c..3bc5816b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -21,6 +21,7 @@ from lean_spec.subspecs.networking.discovery.handshake import HandshakeManager from lean_spec.subspecs.networking.discovery.keys import compute_node_id, derive_keys_from_pubkey from lean_spec.subspecs.networking.discovery.messages import ( + Nonce, PacketFlag, Ping, RequestId, @@ -41,7 +42,7 @@ from lean_spec.subspecs.networking.discovery.session import Session, SessionCache from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64 +from lean_spec.types import Bytes12, Bytes16, Bytes64 @pytest.fixture @@ -49,7 +50,7 @@ def node_a_keys(): """Node A's keypair.""" priv, pub = generate_secp256k1_keypair() node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": bytes(node_id)} + return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} @pytest.fixture @@ -57,7 +58,7 @@ def node_b_keys(): """Node B's keypair.""" priv, pub = generate_secp256k1_keypair() node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": bytes(node_id)} + return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} class TestEncryptedPacketRoundtrip: @@ -67,7 +68,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): """MESSAGE packet encrypts and decrypts correctly.""" # Build mock challenge_data for key derivation. # Format: masking-iv (16) + static-header (23) + authdata (24) = 63 bytes. - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) static_header = b"discv5" + b"\x00\x01\x01" + bytes(12) + b"\x00\x18" authdata = bytes(24) challenge_data = masking_iv + static_header + authdata @@ -75,10 +76,10 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Create session keys (derived from ECDH). # Node A is initiator. send_key, recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(node_a_keys["private_key"]), + local_private_key=node_a_keys["private_key"], remote_public_key=node_b_keys["public_key"], - local_node_id=Bytes32(node_a_keys["node_id"]), - remote_node_id=Bytes32(node_b_keys["node_id"]), + local_node_id=node_a_keys["node_id"], + remote_node_id=node_b_keys["node_id"], challenge_data=challenge_data, is_initiator=True, ) @@ -98,7 +99,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): packet = encode_packet( dest_node_id=node_b_keys["node_id"], flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=send_key, @@ -115,10 +116,10 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Node B derives keys as recipient (using same challenge_data). b_send_key, b_recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(node_b_keys["private_key"]), + local_private_key=node_b_keys["private_key"], remote_public_key=node_a_keys["public_key"], - local_node_id=Bytes32(node_b_keys["node_id"]), - remote_node_id=Bytes32(node_a_keys["node_id"]), + local_node_id=node_b_keys["node_id"], + remote_node_id=node_a_keys["node_id"], challenge_data=challenge_data, is_initiator=False, ) @@ -145,8 +146,8 @@ def test_session_cache_operations(self, node_a_keys, node_b_keys): now = time.time() session = Session( node_id=node_b_keys["node_id"], - send_key=bytes(16), - recv_key=bytes(16), + send_key=Bytes16(bytes(16)), + recv_key=Bytes16(bytes(16)), created_at=now, last_seen=now, is_initiator=True, @@ -231,8 +232,8 @@ def test_whoareyou_generation(self, node_a_keys, node_b_keys): ) # Create WHOAREYOU. - request_nonce = bytes(12) - masking_iv = bytes(16) + request_nonce = Nonce(bytes(12)) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id=node_b_keys["node_id"], request_nonce=request_nonce, @@ -327,8 +328,8 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): manager_a.start_handshake(node_b_keys["node_id"]) # Step 2: Node B creates WHOAREYOU. - request_nonce = bytes(12) - masking_iv = bytes(16) + request_nonce = Nonce(bytes(12)) + masking_iv = Bytes16(bytes(16)) id_nonce, whoareyou_authdata, _, challenge_data = manager_b.create_whoareyou( remote_node_id=node_a_keys["node_id"], request_nonce=request_nonce, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_keys.py b/tests/lean_spec/subspecs/networking/discovery/test_keys.py index b30bc151..038fd31a 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_keys.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_keys.py @@ -1,7 +1,5 @@ """Tests for Discovery v5 key derivation.""" -import pytest - from lean_spec.subspecs.networking.discovery.crypto import ( generate_secp256k1_keypair, pubkey_to_uncompressed, @@ -81,21 +79,6 @@ def test_order_matters(self): assert keys_ab != keys_ba - def test_invalid_secret_length_raises(self): - """Test that invalid secret length raises ValueError.""" - with pytest.raises(ValueError, match="Secret must be 33 bytes"): - derive_keys(bytes(32), bytes(32), bytes(32), make_challenge_data()) # type: ignore[arg-type] - - def test_invalid_initiator_id_length_raises(self): - """Test that invalid initiator ID length raises ValueError.""" - with pytest.raises(ValueError, match="Initiator ID must be 32 bytes"): - derive_keys(bytes(33), bytes(31), bytes(32), make_challenge_data()) # type: ignore[arg-type] - - def test_invalid_recipient_id_length_raises(self): - """Test that invalid recipient ID length raises ValueError.""" - with pytest.raises(ValueError, match="Recipient ID must be 32 bytes"): - derive_keys(bytes(33), bytes(32), bytes(31), make_challenge_data()) # type: ignore[arg-type] - class TestDeriveKeysFromPubkey: """Tests for key derivation from ECDH.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py index 7ea230dc..30495bae 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_packet.py @@ -4,7 +4,7 @@ from lean_spec.subspecs.networking.discovery.config import MAX_PACKET_SIZE, MIN_PACKET_SIZE from lean_spec.subspecs.networking.discovery.crypto import aes_ctr_encrypt -from lean_spec.subspecs.networking.discovery.messages import PacketFlag +from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce, PacketFlag from lean_spec.subspecs.networking.discovery.packet import ( HANDSHAKE_HEADER_SIZE, MESSAGE_AUTHDATA_SIZE, @@ -22,7 +22,7 @@ generate_nonce, ) from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16 +from lean_spec.types import Bytes16, Bytes33, Bytes64 class TestNonceGeneration: @@ -75,7 +75,7 @@ class TestWhoAreYouAuthdata: def test_encode_whoareyou_authdata(self): """Test WHOAREYOU authdata encoding.""" - id_nonce = bytes(16) + id_nonce = IdNonce(bytes(16)) enr_seq = SeqNumber(42) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -84,24 +84,24 @@ def test_encode_whoareyou_authdata(self): def test_decode_whoareyou_authdata(self): """Test WHOAREYOU authdata decoding.""" - id_nonce = bytes.fromhex("aa" * 16) + id_nonce = IdNonce(bytes.fromhex("aa" * 16)) enr_seq = SeqNumber(12345) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce + assert decoded.id_nonce == id_nonce assert decoded.enr_seq == enr_seq def test_roundtrip(self): """Test encoding then decoding preserves values.""" - id_nonce = bytes.fromhex("01" * 16) + id_nonce = IdNonce(bytes.fromhex("01" * 16)) enr_seq = SeqNumber(2**63 - 1) # Max uint64 authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce + assert decoded.id_nonce == id_nonce assert decoded.enr_seq == enr_seq def test_invalid_size_raises(self): @@ -116,8 +116,8 @@ class TestHandshakeAuthdata: def test_encode_handshake_authdata(self): """Test HANDSHAKE authdata encoding.""" src_id = NodeId(bytes(32)) - id_signature = bytes(64) - eph_pubkey = bytes([0x02]) + bytes(32) # Compressed pubkey format + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) @@ -128,8 +128,8 @@ def test_encode_handshake_authdata(self): def test_decode_handshake_authdata(self): """Test HANDSHAKE authdata decoding.""" src_id = NodeId(bytes.fromhex("aa" * 32)) - id_signature = bytes.fromhex("bb" * 64) - eph_pubkey = bytes([0x02]) + bytes.fromhex("cc" * 32) + id_signature = Bytes64(bytes.fromhex("bb" * 64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes.fromhex("cc" * 32)) authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) decoded = decode_handshake_authdata(authdata) @@ -144,8 +144,8 @@ def test_decode_handshake_authdata(self): def test_with_enr_record(self): """Test HANDSHAKE authdata with ENR record.""" src_id = NodeId(bytes(32)) - id_signature = bytes(64) - eph_pubkey = bytes([0x02]) + bytes(32) + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) record = b"enr:-IS4QHCYrY..." # Mock ENR authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey, record) @@ -153,25 +153,6 @@ def test_with_enr_record(self): assert decoded.record == record - def test_invalid_src_id_length_raises(self): - """Test that invalid src_id length raises ValueError.""" - with pytest.raises(ValueError, match="Source ID must be 32 bytes"): - encode_handshake_authdata( - bytes(31), # type: ignore[arg-type] - bytes(64), - bytes(33), - ) - - def test_invalid_signature_length_raises(self): - """Test that invalid signature length raises ValueError.""" - with pytest.raises(ValueError, match="Signature must be 64 bytes"): - encode_handshake_authdata(NodeId(bytes(32)), bytes(63), bytes(33)) - - def test_invalid_eph_pubkey_length_raises(self): - """Test that invalid ephemeral pubkey length raises ValueError.""" - with pytest.raises(ValueError, match="Ephemeral pubkey must be 33 bytes"): - encode_handshake_authdata(NodeId(bytes(32)), bytes(64), bytes(32)) - class TestPacketEncoding: """Tests for full packet encoding/decoding.""" @@ -180,10 +161,10 @@ def test_encode_message_packet(self): """Test MESSAGE packet encoding.""" dest_node_id = NodeId(bytes(32)) src_node_id = NodeId(bytes(32)) - nonce = bytes(12) + nonce = Nonce(bytes(12)) authdata = encode_message_authdata(src_node_id) message = b"encrypted message" - encryption_key = bytes(16) + encryption_key = Bytes16(bytes(16)) packet = encode_packet( dest_node_id=dest_node_id, @@ -200,8 +181,8 @@ def test_encode_message_packet(self): def test_encode_whoareyou_packet(self): """Test WHOAREYOU packet encoding.""" dest_node_id = NodeId(bytes(32)) - nonce = bytes(12) - id_nonce = bytes(16) + nonce = Nonce(bytes(12)) + id_nonce = IdNonce(bytes(16)) authdata = encode_whoareyou_authdata(id_nonce, SeqNumber(0)) packet = encode_packet( @@ -210,7 +191,7 @@ def test_encode_whoareyou_packet(self): nonce=nonce, authdata=authdata, message=b"", - encryption_key=None, # WHOAREYOU doesn't encrypt + encryption_key=None, ) # WHOAREYOU has no message content @@ -220,8 +201,8 @@ def test_encode_whoareyou_packet(self): def test_decode_packet_header(self): """Test packet header decoding.""" local_node_id = NodeId(bytes(32)) - nonce = bytes(12) - authdata = encode_whoareyou_authdata(bytes(16), SeqNumber(42)) + nonce = Nonce(bytes(12)) + authdata = encode_whoareyou_authdata(IdNonce(bytes(16)), SeqNumber(42)) packet = encode_packet( dest_node_id=local_node_id, @@ -235,34 +216,10 @@ def test_decode_packet_header(self): header, message_bytes, _message_ad = decode_packet_header(local_node_id, packet) assert header.flag == PacketFlag.WHOAREYOU - assert bytes(header.nonce) == nonce + assert header.nonce == nonce assert header.authdata == authdata assert message_bytes == b"" - def test_invalid_dest_node_id_length_raises(self): - """Test that invalid dest_node_id length raises ValueError.""" - with pytest.raises(ValueError, match="Destination node ID must be 32 bytes"): - encode_packet( - dest_node_id=bytes(31), # type: ignore[arg-type] - flag=PacketFlag.MESSAGE, - nonce=bytes(12), - authdata=bytes(32), - message=b"", - encryption_key=bytes(16), - ) - - def test_invalid_nonce_length_raises(self): - """Test that invalid nonce length raises ValueError.""" - with pytest.raises(ValueError, match="Nonce must be 12 bytes"): - encode_packet( - dest_node_id=NodeId(bytes(32)), - flag=PacketFlag.MESSAGE, - nonce=bytes(11), - authdata=bytes(32), - message=b"", - encryption_key=bytes(16), - ) - class TestConstants: """Tests for packet format constants.""" @@ -327,8 +284,8 @@ def test_encode_packet_enforces_max_size(self): """encode_packet raises error if packet exceeds max size.""" src_id = NodeId(bytes(32)) dest_id = NodeId(bytes(32)) - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) # Create authdata. authdata = encode_message_authdata(src_id) @@ -391,7 +348,7 @@ def test_message_flag_without_encryption_key_raises(self): encode_packet( dest_node_id=NodeId(bytes(32)), flag=PacketFlag.MESSAGE, - nonce=bytes(12), + nonce=Nonce(bytes(12)), authdata=encode_message_authdata(NodeId(bytes(32))), message=b"\x01\xc2\x01\x01", encryption_key=None, @@ -401,35 +358,21 @@ def test_handshake_flag_without_encryption_key_raises(self): """HANDSHAKE packets require an encryption key.""" authdata = encode_handshake_authdata( src_id=NodeId(bytes(32)), - id_signature=bytes(64), - eph_pubkey=bytes([0x02]) + bytes(32), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes([0x02]) + bytes(32)), ) with pytest.raises(ValueError, match="Encryption key required"): encode_packet( dest_node_id=NodeId(bytes(32)), flag=PacketFlag.HANDSHAKE, - nonce=bytes(12), + nonce=Nonce(bytes(12)), authdata=authdata, message=b"\x01\xc2\x01\x01", encryption_key=None, ) -class TestAuthdataInvalidLengths: - """Edge cases for authdata encoding with invalid input lengths.""" - - def test_encode_whoareyou_authdata_wrong_id_nonce_length(self): - """WHOAREYOU authdata rejects id_nonce that is not 16 bytes.""" - with pytest.raises(ValueError, match="ID nonce must be 16 bytes"): - encode_whoareyou_authdata(bytes(15), SeqNumber(0)) - - def test_encode_message_authdata_wrong_src_id_length(self): - """MESSAGE authdata rejects src_id that is not 32 bytes.""" - with pytest.raises(ValueError, match="Source ID must be 32 bytes"): - encode_message_authdata(bytes(31)) # type: ignore[arg-type] - - class TestPacketProtocolValidation: """Protocol ID and version validation in packet decoding.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_routing.py b/tests/lean_spec/subspecs/networking/discovery/test_routing.py index dff72956..1ea5578b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_routing.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_routing.py @@ -19,9 +19,8 @@ ) from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH -from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber from lean_spec.types import Bytes64 -from lean_spec.types.byte_arrays import Bytes4 class TestXorDistance: @@ -447,7 +446,7 @@ def test_fork_filter_rejects_without_enr(self, local_node_id, remote_node_id): """With fork filter, nodes without ENR are rejected.""" table = RoutingTable( local_id=local_node_id, - local_fork_digest=Bytes4(bytes(4)), + local_fork_digest=ForkDigest(bytes(4)), ) entry = NodeEntry(node_id=remote_node_id, enr=None) @@ -457,7 +456,7 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ """Nodes without eth2 data are rejected when filtering.""" table = RoutingTable( local_id=local_node_id, - local_fork_digest=Bytes4(bytes(4)), + local_fork_digest=ForkDigest(bytes(4)), ) enr = ENR( @@ -472,7 +471,7 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id): """Node with different fork_digest is rejected.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Build eth2 bytes with a different fork digest. @@ -491,7 +490,7 @@ def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): """Node with matching fork_digest is accepted.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Build eth2 bytes with the same fork digest. @@ -513,7 +512,7 @@ def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): def test_is_fork_compatible_method(self, local_node_id): """Verify is_fork_compatible for compatible, incompatible, and no-ENR entries.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Compatible entry. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py index bda15a64..9e4fb3fb 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_session.py @@ -2,8 +2,6 @@ import time -import pytest - from lean_spec.subspecs.networking.discovery.messages import Port from lean_spec.subspecs.networking.discovery.session import ( BondCache, @@ -11,6 +9,9 @@ SessionCache, ) from lean_spec.subspecs.networking.types import NodeId +from lean_spec.types import Bytes16 + +ZERO_KEY = Bytes16(bytes(16)) class TestSession: @@ -20,8 +21,8 @@ def test_create_session(self): """Test session creation.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time(), last_seen=time.time(), is_initiator=True, @@ -35,8 +36,8 @@ def test_is_expired_false_for_new_session(self): """Test that new session is not expired.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time(), last_seen=time.time(), is_initiator=True, @@ -48,8 +49,8 @@ def test_is_expired_true_for_old_session(self): """Test that old session is expired.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time() - 7200, # 2 hours ago last_seen=time.time() - 7200, is_initiator=True, @@ -61,8 +62,8 @@ def test_touch_updates_last_seen(self): """Test that touch updates last_seen timestamp.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time() - 100, last_seen=time.time() - 100, is_initiator=True, @@ -81,8 +82,8 @@ def test_create_and_get_session(self): """Test creating and retrieving a session.""" cache = SessionCache() node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key = bytes(16) - recv_key = bytes(16) + send_key = ZERO_KEY + recv_key = ZERO_KEY session = cache.create(node_id, send_key, recv_key, is_initiator=True) @@ -101,7 +102,7 @@ def test_get_expired_returns_none(self): cache = SessionCache(timeout_secs=0.001) node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) assert cache.get(node_id) is None @@ -112,7 +113,7 @@ def test_remove_session(self): cache = SessionCache() node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.remove(node_id) assert cache.get(node_id) is None @@ -126,7 +127,7 @@ def test_touch_updates_session(self): cache = SessionCache() node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.touch(node_id) def test_touch_nonexistent_returns_false(self): @@ -140,18 +141,18 @@ def test_count(self): assert cache.count() == 0 - cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 1 - cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 def test_cleanup_expired(self): """Test expired session cleanup.""" cache = SessionCache(timeout_secs=0.001) - cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) - cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) removed = cache.cleanup_expired() @@ -166,36 +167,20 @@ def test_eviction_when_full(self): node2 = NodeId(bytes.fromhex("02" + "00" * 31)) node3 = NodeId(bytes.fromhex("03" + "00" * 31)) - cache.create(node1, bytes(16), bytes(16), is_initiator=True) + cache.create(node1, ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) # Ensure different timestamps - cache.create(node2, bytes(16), bytes(16), is_initiator=True) + cache.create(node2, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 # Adding third should evict first - cache.create(node3, bytes(16), bytes(16), is_initiator=True) + cache.create(node3, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 assert cache.get(node1) is None # Oldest should be evicted assert cache.get(node2) is not None assert cache.get(node3) is not None - def test_invalid_node_id_length_raises(self): - """Test that invalid node ID length raises ValueError.""" - cache = SessionCache() - with pytest.raises(ValueError, match="Node ID must be 32 bytes"): - cache.create(bytes(31), bytes(16), bytes(16), is_initiator=True) # type: ignore[arg-type] - - def test_invalid_key_length_raises(self): - """Test that invalid key lengths raise ValueError.""" - cache = SessionCache() - - with pytest.raises(ValueError, match="Send key must be 16 bytes"): - cache.create(NodeId(bytes(32)), bytes(15), bytes(16), is_initiator=True) - - with pytest.raises(ValueError, match="Recv key must be 16 bytes"): - cache.create(NodeId(bytes(32)), bytes(16), bytes(15), is_initiator=True) - def test_endpoint_keying_separates_sessions(self): """Same node_id at different ip:port has separate sessions. @@ -204,15 +189,15 @@ def test_endpoint_keying_separates_sessions(self): """ cache = SessionCache() node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key_1 = bytes([0x01] * 16) - send_key_2 = bytes([0x02] * 16) + send_key_1 = Bytes16(bytes([0x01] * 16)) + send_key_2 = Bytes16(bytes([0x02] * 16)) # Create sessions for same node at different endpoints. cache.create( - node_id, send_key_1, bytes(16), is_initiator=True, ip="10.0.0.1", port=Port(9000) + node_id, send_key_1, ZERO_KEY, is_initiator=True, ip="10.0.0.1", port=Port(9000) ) cache.create( - node_id, send_key_2, bytes(16), is_initiator=True, ip="10.0.0.2", port=Port(9000) + node_id, send_key_2, ZERO_KEY, is_initiator=True, ip="10.0.0.2", port=Port(9000) ) # Each endpoint retrieves its own session. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_transport.py b/tests/lean_spec/subspecs/networking/discovery/test_transport.py index 95f6a613..223f29f5 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_transport.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_transport.py @@ -279,7 +279,7 @@ async def test_stop_cancels_pending_requests(self, local_node_id, local_private_ request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -318,7 +318,7 @@ def test_create_pending_request(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -410,7 +410,7 @@ def test_pending_multi_request_creation(self, local_node_id, local_private_key, request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, @@ -433,7 +433,7 @@ def test_pending_multi_request_expected_total_tracking(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, @@ -469,7 +469,7 @@ async def test_queue(): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=3, @@ -575,7 +575,7 @@ def test_pending_request_stores_request_id(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -599,7 +599,7 @@ async def test_future(): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -635,7 +635,7 @@ def test_pending_request_future_cancellation(self): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -666,7 +666,7 @@ def test_request_id_bytes_for_dict_lookup(self): request_id=request_id_1, dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message1, future=future1, ) @@ -675,7 +675,7 @@ def test_request_id_bytes_for_dict_lookup(self): request_id=request_id_2, dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message2, future=future2, ) @@ -699,7 +699,7 @@ def test_pending_request_stores_nonce_for_whoareyou_matching(self): loop = asyncio.new_event_loop() future: asyncio.Future = loop.create_future() - nonce = b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c" + nonce = Nonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c") message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( @@ -727,7 +727,7 @@ def test_pending_request_stores_message_for_retransmission(self): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -785,7 +785,7 @@ async def test_pending_requests_cleared_on_stop( request_id=bytes([i]), dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -828,7 +828,7 @@ async def test_pending_request_futures_cancelled_on_stop( request_id=bytes([i]), dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -1092,7 +1092,7 @@ async def test_response_completes_pending_request_future( request_id=request_id, dest_node_id=remote_node_id, sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -1128,7 +1128,7 @@ async def test_response_enqueued_for_multi_request( request_id=request_id, dest_node_id=remote_node_id, sent_at=0.0, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index 6b9c9a20..51f34e6d 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -29,9 +29,11 @@ from lean_spec.subspecs.networking.discovery.messages import ( Distance, FindNode, + IdNonce, IPv4, MessageType, Nodes, + Nonce, PacketFlag, Ping, Pong, @@ -53,7 +55,7 @@ encode_whoareyou_authdata, ) from lean_spec.subspecs.networking.discovery.routing import log2_distance, xor_distance -from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.subspecs.networking.types import SeqNumber from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64 from lean_spec.types.uint import Uint8 from tests.lean_spec.helpers import make_challenge_data @@ -67,27 +69,27 @@ ) # Spec test vector values for ECDH and key derivation. -SPEC_NONCE = bytes.fromhex("0102030405060708090a0b0c") +SPEC_NONCE = Nonce(bytes.fromhex("0102030405060708090a0b0c")) SPEC_CHALLENGE_DATA = bytes.fromhex( "000000000000000000000000000000006469736376350001010102030405060708090a0b0c" "00180102030405060708090a0b0c0d0e0f100000000000000000" ) # Spec ephemeral keypair for ECDH / ID nonce signing. -SPEC_EPHEMERAL_KEY = bytes.fromhex( - "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" +SPEC_EPHEMERAL_KEY = Bytes32( + bytes.fromhex("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") ) -SPEC_EPHEMERAL_PUBKEY = bytes.fromhex( - "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" +SPEC_EPHEMERAL_PUBKEY = Bytes33( + bytes.fromhex("039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231") ) # Derived session keys from spec HKDF test vector. -SPEC_INITIATOR_KEY = bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571") -SPEC_RECIPIENT_KEY = bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5") +SPEC_INITIATOR_KEY = Bytes16(bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571")) +SPEC_RECIPIENT_KEY = Bytes16(bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5")) # AES-GCM test vector values. -SPEC_AES_KEY = bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b") -SPEC_AES_NONCE = bytes.fromhex("27b5af763c446acd2749fe8e") +SPEC_AES_KEY = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) +SPEC_AES_NONCE = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) # PING message plaintext (type 0x01, RLP [1]). SPEC_PING_PLAINTEXT = bytes.fromhex("01c20101") @@ -108,9 +110,11 @@ def test_node_b_id_from_privkey(self): int.from_bytes(NODE_B_PRIVKEY, "big"), ec.SECP256K1(), ) - pubkey_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + pubkey_bytes = Bytes33( + private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) computed = compute_node_id(pubkey_bytes) @@ -126,9 +130,11 @@ def test_node_a_id_from_privkey(self): int.from_bytes(NODE_A_PRIVKEY, "big"), ec.SECP256K1(), ) - pubkey_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + pubkey_bytes = Bytes33( + private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) computed = compute_node_id(pubkey_bytes) assert bytes(computed) == NODE_A_ID @@ -144,13 +150,15 @@ def test_bidirectional_ecdh(self): int.from_bytes(NODE_A_PRIVKEY, "big"), ec.SECP256K1(), ) - a_pubkey_bytes = a_privkey.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + a_pubkey_bytes = Bytes33( + a_privkey.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) - shared_ab = ecdh_agree(Bytes32(NODE_A_PRIVKEY), NODE_B_PUBKEY) - shared_ba = ecdh_agree(Bytes32(NODE_B_PRIVKEY), a_pubkey_bytes) + shared_ab = ecdh_agree(NODE_A_PRIVKEY, NODE_B_PUBKEY) + shared_ba = ecdh_agree(NODE_B_PRIVKEY, a_pubkey_bytes) assert shared_ab == shared_ba @@ -168,7 +176,7 @@ def test_ecdh_shared_secret(self): "033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e" ) - shared = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), SPEC_EPHEMERAL_PUBKEY) + shared = ecdh_agree(SPEC_EPHEMERAL_KEY, SPEC_EPHEMERAL_PUBKEY) assert shared == expected_shared @@ -180,13 +188,13 @@ def test_key_derivation_hkdf(self): Uses exact spec challenge_data (with nonce 0102030405060708090a0b0c). """ # Compute ECDH shared secret. - shared_secret = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), NODE_B_PUBKEY) + shared_secret = ecdh_agree(SPEC_EPHEMERAL_KEY, NODE_B_PUBKEY) # Derive keys using exact spec challenge_data. initiator_key, recipient_key = derive_keys( secret=shared_secret, - initiator_id=Bytes32(NODE_A_ID), - recipient_id=Bytes32(NODE_B_ID), + initiator_id=NODE_A_ID, + recipient_id=NODE_B_ID, challenge_data=SPEC_CHALLENGE_DATA, ) @@ -206,10 +214,10 @@ def test_id_nonce_signature(self): """ # Sign using exact spec challenge_data. signature = sign_id_nonce( - private_key_bytes=Bytes32(SPEC_EPHEMERAL_KEY), + private_key_bytes=SPEC_EPHEMERAL_KEY, challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), - dest_node_id=Bytes32(NODE_B_ID), + ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, + dest_node_id=NODE_B_ID, ) expected_sig = bytes.fromhex( @@ -231,8 +239,8 @@ def test_id_nonce_signature(self): assert verify_id_nonce_signature( signature=Bytes64(signature), challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), - dest_node_id=Bytes32(NODE_B_ID), + ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, + dest_node_id=NODE_B_ID, public_key_bytes=Bytes33(pubkey_bytes), ) @@ -242,16 +250,16 @@ def test_id_nonce_signature_different_challenge_data(self): challenge_data2 = make_challenge_data(bytes([1]) + bytes(15)) sig1 = sign_id_nonce( - Bytes32(NODE_B_PRIVKEY), + NODE_B_PRIVKEY, challenge_data1, - Bytes33(SPEC_EPHEMERAL_PUBKEY), - Bytes32(NODE_A_ID), + SPEC_EPHEMERAL_PUBKEY, + NODE_A_ID, ) sig2 = sign_id_nonce( - Bytes32(NODE_B_PRIVKEY), + NODE_B_PRIVKEY, challenge_data2, - Bytes33(SPEC_EPHEMERAL_PUBKEY), - Bytes32(NODE_A_ID), + SPEC_EPHEMERAL_PUBKEY, + NODE_A_ID, ) assert sig1 != sig2 @@ -266,14 +274,12 @@ def test_aes_gcm_encryption(self): expected_ciphertext = bytes.fromhex("a5d12a2d94b8ccb3ba55558229867dc13bfa3648") # Encrypt. - ciphertext = aes_gcm_encrypt( - Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), SPEC_PING_PLAINTEXT, aad - ) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, SPEC_PING_PLAINTEXT, aad) assert ciphertext == expected_ciphertext # Verify decryption works. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == SPEC_PING_PLAINTEXT @@ -302,8 +308,8 @@ def test_decode_spec_ping_packet(self): assert decoded_authdata.src_id == NODE_A_ID # Decrypt using the spec's read-key (all zeros for this test vector). - read_key = bytes(16) - plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + read_key = Bytes16(bytes(16)) + plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) # PING with request-id=0x00000001 (4 bytes) and enr-seq=2. decoded = decode_message(plaintext) @@ -386,8 +392,8 @@ def test_decode_spec_handshake_with_enr_packet(self): assert len(decoded_authdata.record) > 0 # Decrypt the message using the spec's read-key. - read_key = bytes.fromhex("53b1c075f41876423154e157470c2f48") - plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + read_key = Bytes16(bytes.fromhex("53b1c075f41876423154e157470c2f48")) + plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) # PING with request-id=0x00000001 and enr-seq=1. decoded = decode_message(plaintext) @@ -400,8 +406,8 @@ class TestPacketEncodingRoundtrip: def test_message_packet_roundtrip(self): """MESSAGE packet encodes and decodes correctly.""" - nonce = bytes(12) # 12-byte nonce. - encryption_key = bytes(16) # 16-byte key. + nonce = Nonce(bytes(12)) # 12-byte nonce. + encryption_key = Bytes16(bytes(16)) # 16-byte key. message = b"\x01\xc2\x01\x01" # PING message. authdata = encode_message_authdata(NODE_A_ID) @@ -426,8 +432,8 @@ def test_message_packet_roundtrip(self): def test_whoareyou_packet_roundtrip(self): """WHOAREYOU packet encodes and decodes correctly.""" - nonce = bytes.fromhex("0102030405060708090a0b0c") - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") + nonce = Nonce(bytes.fromhex("0102030405060708090a0b0c")) + id_nonce = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -453,12 +459,12 @@ def test_whoareyou_packet_roundtrip(self): def test_handshake_packet_roundtrip(self): """HANDSHAKE packet encodes and decodes correctly.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) message = b"\x01\xc2\x01\x01" # PING message. - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -587,8 +593,8 @@ def test_message_packet_header_structure(self): - nonce: 12 bytes - authdata-size: 2 bytes """ - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) message = b"\x01\xc2\x01\x01" authdata = encode_message_authdata(NODE_A_ID) @@ -613,8 +619,8 @@ def test_whoareyou_packet_header_structure(self): - authdata: id-nonce (16) + enr-seq (8) = 24 bytes - no message payload """ - nonce = bytes(12) - id_nonce = bytes(16) + nonce = Nonce(bytes(12)) + id_nonce = IdNonce(bytes(16)) enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -640,13 +646,13 @@ def test_handshake_packet_header_structure(self): - authdata: src-id (32) + sig-size (1) + eph-key-size (1) + sig + eph-key + [record] - encrypted message """ - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) message = b"\x01\xc2\x01\x01" - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -723,13 +729,13 @@ def test_aes_gcm_empty_plaintext(self): aad = bytes(32) plaintext = b"" - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Empty plaintext should produce just the 16-byte auth tag. assert len(ciphertext) == 16 # Decryption should recover empty plaintext. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == b"" def test_aes_gcm_large_plaintext(self): @@ -737,13 +743,13 @@ def test_aes_gcm_large_plaintext(self): aad = bytes(32) plaintext = bytes(1024) # 1KB of zeros. - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Ciphertext = plaintext length + 16-byte tag. assert len(ciphertext) == len(plaintext) + 16 # Decryption should recover original plaintext. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == plaintext def test_aes_gcm_tampered_ciphertext_fails(self): @@ -751,7 +757,7 @@ def test_aes_gcm_tampered_ciphertext_fails(self): aad = bytes(32) plaintext = b"secret message" - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Tamper with ciphertext by flipping a bit. tampered = bytearray(ciphertext) @@ -760,7 +766,7 @@ def test_aes_gcm_tampered_ciphertext_fails(self): # Decryption of tampered ciphertext should fail with InvalidTag. with pytest.raises(InvalidTag): - aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), tampered, aad) + aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, tampered, aad) class TestSpecPacketPayloadDecryption: @@ -768,7 +774,7 @@ class TestSpecPacketPayloadDecryption: def test_message_packet_encrypt_decrypt_roundtrip(self): """Encrypt a message in a packet and decrypt using message_ad from decode.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) authdata = encode_message_authdata(NODE_A_ID) @@ -785,16 +791,16 @@ def test_message_packet_encrypt_decrypt_roundtrip(self): header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) assert decrypted == SPEC_PING_PLAINTEXT def test_handshake_packet_encrypt_decrypt_roundtrip(self): """Handshake packet encrypts and decrypts using correct AAD.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -817,7 +823,7 @@ def test_handshake_packet_encrypt_decrypt_roundtrip(self): header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) assert decrypted == SPEC_PING_PLAINTEXT @@ -826,17 +832,12 @@ class TestRoutingWithTestVectorNodeIds: def test_xor_distance_is_symmetric(self): """XOR distance between test vector nodes is symmetric and non-zero.""" - node_a = NodeId(NODE_A_ID) - node_b = NodeId(NODE_B_ID) - - distance = xor_distance(node_a, node_b) + distance = xor_distance(NODE_A_ID, NODE_B_ID) assert distance > 0 - assert xor_distance(node_a, node_b) == xor_distance(node_b, node_a) + assert xor_distance(NODE_A_ID, NODE_B_ID) == xor_distance(NODE_B_ID, NODE_A_ID) def test_log2_distance_is_high(self): """Log2 distance between test vector nodes is high (differ in high bits).""" - node_a = NodeId(NODE_A_ID) - node_b = NodeId(NODE_B_ID) - log_dist = log2_distance(node_a, node_b) + log_dist = log2_distance(NODE_A_ID, NODE_B_ID) assert log_dist > Distance(200) diff --git a/tests/lean_spec/subspecs/networking/enr/test_eth2.py b/tests/lean_spec/subspecs/networking/enr/test_eth2.py index fb437a79..28d945ed 100644 --- a/tests/lean_spec/subspecs/networking/enr/test_eth2.py +++ b/tests/lean_spec/subspecs/networking/enr/test_eth2.py @@ -10,8 +10,8 @@ AttestationSubnets, SyncCommitteeSubnets, ) +from lean_spec.subspecs.networking.types import ForkDigest, Version from lean_spec.types import Uint64 -from lean_spec.types.byte_arrays import Bytes4 class TestEth2Data: @@ -20,17 +20,17 @@ class TestEth2Data: def test_create_eth2_data(self) -> None: """Eth2Data can be created with valid parameters.""" data = Eth2Data( - fork_digest=Bytes4(b"\x12\x34\x56\x78"), - next_fork_version=Bytes4(b"\x02\x00\x00\x00"), + fork_digest=ForkDigest(b"\x12\x34\x56\x78"), + next_fork_version=Version(b"\x02\x00\x00\x00"), next_fork_epoch=Uint64(194048), ) - assert data.fork_digest == Bytes4(b"\x12\x34\x56\x78") + assert data.fork_digest == ForkDigest(b"\x12\x34\x56\x78") assert data.next_fork_epoch == Uint64(194048) def test_no_scheduled_fork_factory(self) -> None: """no_scheduled_fork factory creates correct data.""" - digest = Bytes4(b"\xab\xcd\xef\x01") - version = Bytes4(b"\x01\x00\x00\x00") + digest = ForkDigest(b"\xab\xcd\xef\x01") + version = Version(b"\x01\x00\x00\x00") data = Eth2Data.no_scheduled_fork(digest, version) assert data.fork_digest == digest @@ -40,12 +40,12 @@ def test_no_scheduled_fork_factory(self) -> None: def test_eth2_data_immutable(self) -> None: """Eth2Data is immutable (frozen).""" data = Eth2Data( - fork_digest=Bytes4(b"\x12\x34\x56\x78"), - next_fork_version=Bytes4(b"\x02\x00\x00\x00"), + fork_digest=ForkDigest(b"\x12\x34\x56\x78"), + next_fork_version=Version(b"\x02\x00\x00\x00"), next_fork_epoch=Uint64(0), ) with pytest.raises(ValidationError): - data.fork_digest = Bytes4(b"\x00\x00\x00\x00") + data.fork_digest = ForkDigest(b"\x00\x00\x00\x00") def test_far_future_epoch_value(self) -> None: """FAR_FUTURE_EPOCH is max uint64.""" @@ -189,7 +189,7 @@ def test_from_subnet_ids_with_duplicates(self) -> None: """from_subnet_ids handles duplicates correctly.""" subnets = SyncCommitteeSubnets.from_subnet_ids([1, 1, 1, 3]) assert subnets.subscription_count() == 2 - assert subnets.subscribed_subnets() == [1, 3] + assert subnets.subscribed_subnets() == [SubnetId(1), SubnetId(3)] def test_from_subnet_ids_invalid(self) -> None: """from_subnet_ids() raises for invalid subnet IDs.""" @@ -202,7 +202,7 @@ def test_from_subnet_ids_invalid(self) -> None: def test_subscribed_subnets(self) -> None: """subscribed_subnets() returns correct list.""" subnets = SyncCommitteeSubnets.from_subnet_ids([1, 3]) - assert subnets.subscribed_subnets() == [1, 3] + assert subnets.subscribed_subnets() == [SubnetId(1), SubnetId(3)] def test_subscription_count(self) -> None: """subscription_count() returns correct count.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py b/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py index b3d3cfa1..0b466c55 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py @@ -6,7 +6,7 @@ from lean_spec.subspecs.networking.gossipsub.mcache import MessageCache, SeenCache from lean_spec.subspecs.networking.gossipsub.message import GossipsubMessage -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId class TestMessageCacheShift: @@ -158,7 +158,7 @@ def test_get_retrieves_cached_message(self) -> None: def test_get_returns_none_for_unknown(self) -> None: """get() returns None for an unknown message ID.""" cache = MessageCache() - assert cache.get(Bytes20(b"\x00" * 20)) is None + assert cache.get(MessageId(b"\x00" * 20)) is None def test_put_duplicate_returns_false(self) -> None: """Putting the same message twice returns False on second call.""" @@ -176,7 +176,7 @@ def test_has_method(self) -> None: cache.put("t", msg) assert cache.has(msg.id) - assert not cache.has(Bytes20(b"\x00" * 20)) + assert not cache.has(MessageId(b"\x00" * 20)) class TestSeenCache: @@ -185,13 +185,13 @@ class TestSeenCache: def test_add_returns_true_for_new(self) -> None: """add() returns True for a new message ID.""" seen = SeenCache(ttl_seconds=120) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") assert seen.add(msg_id, time.time()) is True def test_add_returns_false_for_duplicate(self) -> None: """add() returns False for an already-seen message ID.""" seen = SeenCache(ttl_seconds=120) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") seen.add(msg_id, time.time()) assert seen.add(msg_id, time.time()) is False @@ -200,8 +200,8 @@ def test_cleanup_removes_expired(self) -> None: seen = SeenCache(ttl_seconds=10) now = time.time() - old_id = Bytes20(b"aaaaaaaaaaaaaaaaaaaa") - fresh_id = Bytes20(b"bbbbbbbbbbbbbbbbbbbb") + old_id = MessageId(b"aaaaaaaaaaaaaaaaaaaa") + fresh_id = MessageId(b"bbbbbbbbbbbbbbbbbbbb") seen.add(old_id, now - 20) seen.add(fresh_id, now) @@ -214,7 +214,7 @@ def test_cleanup_no_expired(self) -> None: """cleanup() with no expired entries removes nothing.""" seen = SeenCache(ttl_seconds=120) now = time.time() - seen.add(Bytes20(b"12345678901234567890"), now) + seen.add(MessageId(b"12345678901234567890"), now) removed = seen.cleanup(now) assert removed == 0 @@ -224,7 +224,7 @@ def test_clear_empties_all(self) -> None: """clear() removes all entries.""" seen = SeenCache() for i in range(5): - seen.add(Bytes20(f"x{i:019d}".encode()), time.time()) + seen.add(MessageId(f"x{i:019d}".encode()), time.time()) assert len(seen) == 5 seen.clear() @@ -233,11 +233,11 @@ def test_clear_empties_all(self) -> None: def test_has_method(self) -> None: """The has() method works for seen message IDs.""" seen = SeenCache() - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") seen.add(msg_id, time.time()) assert seen.has(msg_id) - assert not seen.has(Bytes20(b"\x00" * 20)) + assert not seen.has(MessageId(b"\x00" * 20)) class TestGossipsubMessageId: @@ -262,10 +262,10 @@ def test_id_differs_with_different_topic(self) -> None: assert msg1.id != msg2.id def test_id_is_20_bytes(self) -> None: - """Message ID is exactly 20 bytes (Bytes20).""" + """Message ID is exactly 20 bytes.""" msg = GossipsubMessage(topic=b"t", raw_data=b"d") assert len(msg.id) == 20 - assert isinstance(msg.id, Bytes20) + assert isinstance(msg.id, MessageId) def test_id_is_cached(self) -> None: """The ID is computed once and reused on subsequent accesses.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py b/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py index f3c2d589..209a13cb 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py @@ -24,7 +24,7 @@ Message, SubOpts, ) -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId from .conftest import add_peer, make_behavior, make_peer @@ -217,7 +217,7 @@ async def test_ihave_ignores_seen(self) -> None: """IHAVE for already-seen messages does not trigger IWANT.""" behavior, capture = make_behavior() peer_id = add_peer(behavior, "peer1") - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") # Mark as seen behavior.seen_cache.add(msg_id, time.time()) @@ -233,7 +233,7 @@ async def test_ihave_partial_seen(self) -> None: behavior, capture = make_behavior() peer_id = add_peer(behavior, "peer1") - seen_id = Bytes20(b"seen_msg_id_1234seen") + seen_id = MessageId(b"seen_msg_id_1234seen") unseen_id = b"unseen_msg_id_12unse" behavior.seen_cache.add(seen_id, time.time()) @@ -535,7 +535,7 @@ def test_idontwant_populates_peer_set(self) -> None: state = behavior._peers[peer_id] for mid in msg_ids: - assert Bytes20(mid) in state.dont_want_ids + assert MessageId(mid) in state.dont_want_ids def test_idontwant_unknown_peer(self) -> None: """IDONTWANT for unknown peer is silently ignored.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py b/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py index 4743831e..02bb934d 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py @@ -20,7 +20,7 @@ ControlMessage, ControlPrune, ) -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId from .conftest import add_peer, make_behavior, make_peer @@ -255,7 +255,7 @@ async def test_cleans_seen_cache(self) -> None: behavior, _ = make_behavior() behavior.seen_cache = SeenCache(ttl_seconds=1) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") behavior.seen_cache.add(msg_id, time.time() - 10) # Already expired await behavior._heartbeat() @@ -310,7 +310,7 @@ async def test_clears_idontwant_sets(self) -> None: """Heartbeat clears per-peer IDONTWANT sets.""" behavior, _ = make_behavior() pid = add_peer(behavior, "peer1") - behavior._peers[pid].dont_want_ids.add(Bytes20(b"12345678901234567890")) + behavior._peers[pid].dont_want_ids.add(MessageId(b"12345678901234567890")) assert len(behavior._peers[pid].dont_want_ids) == 1 diff --git a/tests/lean_spec/subspecs/networking/test_network_service.py b/tests/lean_spec/subspecs/networking/test_network_service.py index 9e657f4d..027fb75e 100644 --- a/tests/lean_spec/subspecs/networking/test_network_service.py +++ b/tests/lean_spec/subspecs/networking/test_network_service.py @@ -171,7 +171,7 @@ async def test_attestation_processed_by_store( attestation = SignedAttestation( validator_id=ValidatorIndex(42), - message=AttestationData( + data=AttestationData( slot=Slot(1), head=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), target=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), @@ -218,7 +218,7 @@ async def test_attestation_ignored_in_idle_state( attestation = SignedAttestation( validator_id=ValidatorIndex(99), - message=AttestationData( + data=AttestationData( slot=Slot(1), head=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), target=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), diff --git a/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py b/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py index baa7455f..f2a4a6c6 100644 --- a/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py +++ b/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py @@ -1,7 +1,5 @@ """Tests for secp256k1 identity keypair.""" -import pytest - from lean_spec.subspecs.networking.transport.identity import ( IdentityKeypair, verify_signature, @@ -34,21 +32,10 @@ def test_generate_unique(self) -> None: def test_from_bytes_roundtrip(self) -> None: """Keypair can be loaded from raw bytes.""" original = IdentityKeypair.generate() - private_bytes = original.private_key_bytes() - - restored = IdentityKeypair.from_bytes(private_bytes) - + restored = IdentityKeypair.from_bytes(original.private_key_bytes()) assert restored.public_key_bytes() == original.public_key_bytes() assert restored.private_key_bytes() == original.private_key_bytes() - def test_from_bytes_invalid_length(self) -> None: - """Loading from invalid bytes raises ValueError.""" - with pytest.raises(ValueError, match="Expected 32 bytes"): - IdentityKeypair.from_bytes(b"\x00" * 16) - - with pytest.raises(ValueError, match="Expected 32 bytes"): - IdentityKeypair.from_bytes(b"\x00" * 64) - def test_sign_and_verify(self) -> None: """Signatures can be verified.""" keypair = IdentityKeypair.generate() diff --git a/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py b/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py index e67aaefa..3c519946 100644 --- a/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py +++ b/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py @@ -8,6 +8,7 @@ create_identity_proof, verify_identity_proof, ) +from lean_spec.types import Bytes32 class TestIdentityProof: @@ -16,25 +17,25 @@ class TestIdentityProof: def test_create_and_verify(self) -> None: """Identity proof can be verified.""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key, noise_public_key) + proof = create_identity_proof(identity_key, public_key) assert verify_identity_proof( identity_key.public_key_bytes(), - noise_public_key, + public_key, proof, ) - def test_verify_wrong_noise_key(self) -> None: - """Verification fails with wrong Noise key.""" + def test_verify_wrong_key(self) -> None: + """Verification fails with wrong public key.""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) - wrong_noise_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) + wrong_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key, noise_public_key) + proof = create_identity_proof(identity_key, public_key) assert not verify_identity_proof( identity_key.public_key_bytes(), - wrong_noise_key, + wrong_key, proof, ) @@ -42,45 +43,45 @@ def test_verify_wrong_identity_key(self) -> None: """Verification fails with wrong identity key.""" identity_key1 = IdentityKeypair.generate() identity_key2 = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key1, noise_public_key) + proof = create_identity_proof(identity_key1, public_key) assert not verify_identity_proof( identity_key2.public_key_bytes(), - noise_public_key, + public_key, proof, ) def test_proof_is_deterministic(self) -> None: """Same inputs produce same proof format (but not same bytes due to ECDSA k).""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof1 = create_identity_proof(identity_key, noise_public_key) - proof2 = create_identity_proof(identity_key, noise_public_key) + proof1 = create_identity_proof(identity_key, public_key) + proof2 = create_identity_proof(identity_key, public_key) - assert verify_identity_proof(identity_key.public_key_bytes(), noise_public_key, proof1) - assert verify_identity_proof(identity_key.public_key_bytes(), noise_public_key, proof2) + assert verify_identity_proof(identity_key.public_key_bytes(), public_key, proof1) + assert verify_identity_proof(identity_key.public_key_bytes(), public_key, proof2) def test_noise_identity_prefix(self) -> None: """NOISE_IDENTITY_PREFIX matches libp2p-noise spec.""" assert NOISE_IDENTITY_PREFIX == b"noise-libp2p-static-key:" - def test_proof_binds_identity_to_noise_key(self) -> None: + def test_proof_binds_identity_to_key(self) -> None: """Proof prevents identity key substitution.""" identity_key_real = IdentityKeypair.generate() identity_key_attacker = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key_real, noise_public_key) + proof = create_identity_proof(identity_key_real, public_key) assert verify_identity_proof( identity_key_real.public_key_bytes(), - noise_public_key, + public_key, proof, ) assert not verify_identity_proof( identity_key_attacker.public_key_bytes(), - noise_public_key, + public_key, proof, ) diff --git a/tests/lean_spec/subspecs/networking/transport/test_peer_id.py b/tests/lean_spec/subspecs/networking/transport/test_peer_id.py index a49002d8..67816801 100644 --- a/tests/lean_spec/subspecs/networking/transport/test_peer_id.py +++ b/tests/lean_spec/subspecs/networking/transport/test_peer_id.py @@ -23,6 +23,7 @@ PeerId, PublicKeyProto, ) +from lean_spec.types import Bytes33 # Protobuf tag constants for test assertions _PROTOBUF_TAG_TYPE = 0x08 # (1 << 3) | 0 = field 1, varint @@ -273,14 +274,6 @@ def test_different_keys_different_peerids(self) -> None: assert str(peer_id1) != str(peer_id2) - def test_from_secp256k1_invalid_length(self) -> None: - """from_secp256k1 rejects invalid key lengths.""" - with pytest.raises(ValueError, match="must be 33 bytes"): - PeerId.from_secp256k1(bytes(32)) - - with pytest.raises(ValueError, match="must be 33 bytes"): - PeerId.from_secp256k1(bytes(34)) - class TestPeerIdFormat: """Tests for PeerId format and structure.""" @@ -482,8 +475,8 @@ def test_known_secp256k1_peer_id(self) -> None: This matches the libp2p spec test vector. """ # From spec: 08021221037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99 - key_data = bytes.fromhex( - "037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99" + key_data = Bytes33( + bytes.fromhex("037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99") ) peer_id = PeerId.from_secp256k1(key_data) diff --git a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py index ab6965d1..17fe747a 100644 --- a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py +++ b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py @@ -17,7 +17,7 @@ def test_encode_decode_signed_attestation_roundtrip() -> None: ) signed_attestation = SignedAttestation( validator_id=ValidatorIndex(0), - message=attestation_data, + data=attestation_data, signature=Signature( path=HashTreeOpening(siblings=HashDigestList(data=[])), rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), diff --git a/tests/lean_spec/subspecs/sync/test_service.py b/tests/lean_spec/subspecs/sync/test_service.py index 1b5b795c..54d025e2 100644 --- a/tests/lean_spec/subspecs/sync/test_service.py +++ b/tests/lean_spec/subspecs/sync/test_service.py @@ -417,7 +417,7 @@ async def test_attestation_buffered_when_block_unknown( original_fn = sync_service.store.on_gossip_attestation def reject_unknown(signed_attestation, *, is_aggregator=False): - if signed_attestation.message.target.root == unknown_root: + if signed_attestation.data.target.root == unknown_root: raise KeyError("Unknown block") return original_fn(signed_attestation, is_aggregator=is_aggregator) diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index 0d074a40..7886f9ab 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -15,6 +15,7 @@ SignedAttestation, SignedBlockWithAttestation, ValidatorIndex, + ValidatorIndices, ) from lean_spec.subspecs.containers.attestation import AggregationBits from lean_spec.subspecs.containers.slot import Slot @@ -576,11 +577,11 @@ async def capture_attestation(attestation: SignedAttestation) -> None: for signed_att in attestations_produced: validator_id = signed_att.validator_id public_key = key_manager.get_public_key(validator_id) - message_bytes = signed_att.message.data_root_bytes() + message_bytes = signed_att.data.data_root_bytes() is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - slot=signed_att.message.slot, + slot=signed_att.data.slot, message=message_bytes, sig=signed_att.signature, ) @@ -619,7 +620,7 @@ async def capture_attestation(attestation: SignedAttestation) -> None: expected_source = store.latest_justified for signed_att in attestations_produced: - data = signed_att.message + data = signed_att.data # Verify head checkpoint references the store's head assert data.head.root == expected_head_root @@ -711,7 +712,9 @@ async def test_block_includes_pending_attestations( attestation_map[vid] = attestation_data proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=public_keys, signatures=signatures, message=data_root, @@ -774,7 +777,7 @@ async def test_multiple_slots_produce_different_attestations( } async def capture_attestation(attestation: SignedAttestation) -> None: - attestations_by_slot[attestation.message.slot].append(attestation) + attestations_by_slot[attestation.data.slot].append(attestation) service = ValidatorService( sync_service=real_sync_service, @@ -792,9 +795,9 @@ async def capture_attestation(attestation: SignedAttestation) -> None: # Attestations at each slot should have the correct slot value for att in attestations_by_slot[Slot(1)]: - assert att.message.slot == Slot(1) + assert att.data.slot == Slot(1) for att in attestations_by_slot[Slot(2)]: - assert att.message.slot == Slot(2) + assert att.data.slot == Slot(2) async def test_proposer_does_not_double_attest( self, @@ -921,7 +924,7 @@ async def capture_attestation(attestation: SignedAttestation) -> None: for signed_att in attestations_produced: validator_id = signed_att.validator_id public_key = key_manager.get_public_key(validator_id) - message_bytes = signed_att.message.data_root_bytes() + message_bytes = signed_att.data.data_root_bytes() # Verification must use the same slot that was used for signing is_valid = TARGET_SIGNATURE_SCHEME.verify(