From 51efc7a68bbb5207c3f403afcdbe11747e00f04a Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 17:35:17 +0100 Subject: [PATCH 1/2] chore: python and typing improvements --- .claude/ralph-loop.local.md | 9 + CLAUDE.md | 24 +++ .../test_fixtures/fork_choice.py | 9 +- .../test_fixtures/verify_signatures.py | 12 +- src/lean_spec/__main__.py | 3 +- src/lean_spec/subspecs/chain/clock.py | 2 +- .../attestation/aggregation_bits.py | 12 +- .../containers/attestation/attestation.py | 14 +- .../subspecs/containers/state/state.py | 92 ++++++++-- .../subspecs/containers/state/types.py | 2 +- src/lean_spec/subspecs/forkchoice/store.py | 77 +------- src/lean_spec/subspecs/networking/config.py | 6 +- .../subspecs/networking/discovery/crypto.py | 32 +--- .../networking/discovery/handshake.py | 63 +++---- .../subspecs/networking/discovery/keys.py | 15 +- .../subspecs/networking/discovery/messages.py | 9 +- .../subspecs/networking/discovery/packet.py | 52 ++---- .../subspecs/networking/discovery/service.py | 5 +- .../subspecs/networking/discovery/session.py | 18 +- .../networking/discovery/transport.py | 40 ++--- src/lean_spec/subspecs/networking/enr/enr.py | 12 +- src/lean_spec/subspecs/networking/enr/eth2.py | 4 +- src/lean_spec/subspecs/networking/enr/keys.py | 2 +- .../subspecs/networking/gossipsub/behavior.py | 8 +- .../subspecs/networking/gossipsub/message.py | 3 +- .../subspecs/networking/gossipsub/types.py | 13 +- src/lean_spec/subspecs/networking/peer.py | 8 +- .../subspecs/networking/service/service.py | 2 +- .../networking/transport/identity/keypair.py | 28 ++- .../transport/identity/signature.py | 33 ++-- .../subspecs/networking/transport/peer_id.py | 11 +- src/lean_spec/subspecs/networking/types.py | 28 +-- src/lean_spec/subspecs/node/node.py | 2 +- src/lean_spec/subspecs/ssz/hash.py | 4 +- src/lean_spec/subspecs/sync/block_cache.py | 2 +- .../subspecs/sync/checkpoint_sync.py | 4 +- src/lean_spec/subspecs/sync/head_sync.py | 2 +- src/lean_spec/subspecs/sync/states.py | 2 +- src/lean_spec/subspecs/validator/service.py | 2 +- src/lean_spec/subspecs/xmss/containers.py | 6 +- src/lean_spec/subspecs/xmss/interface.py | 2 +- src/lean_spec/subspecs/xmss/message_hash.py | 2 +- src/lean_spec/subspecs/xmss/poseidon.py | 2 +- src/lean_spec/subspecs/xmss/prf.py | 2 +- src/lean_spec/subspecs/xmss/subtree.py | 12 +- src/lean_spec/subspecs/xmss/tweak_hash.py | 2 +- .../devnet/ssz/test_consensus_containers.py | 2 +- tests/lean_spec/conftest.py | 5 +- tests/lean_spec/helpers/builders.py | 12 +- .../test_attestation_aggregation.py | 8 +- .../test_state_process_attestations.py | 6 +- .../forkchoice/test_attestation_target.py | 14 +- .../forkchoice/test_store_attestations.py | 32 ++-- .../subspecs/forkchoice/test_store_pruning.py | 6 +- .../forkchoice/test_time_management.py | 11 +- .../subspecs/forkchoice/test_validator.py | 24 +-- .../subspecs/networking/discovery/conftest.py | 19 +- .../networking/discovery/test_crypto.py | 32 ---- .../networking/discovery/test_handshake.py | 84 ++++----- .../networking/discovery/test_integration.py | 35 ++-- .../networking/discovery/test_keys.py | 17 -- .../networking/discovery/test_packet.py | 111 +++--------- .../networking/discovery/test_routing.py | 13 +- .../networking/discovery/test_session.py | 69 +++----- .../networking/discovery/test_transport.py | 32 ++-- .../networking/discovery/test_vectors.py | 165 +++++++++--------- .../subspecs/networking/enr/test_eth2.py | 22 +-- .../gossipsub/test_cache_edge_cases.py | 26 +-- .../networking/gossipsub/test_handlers.py | 8 +- .../networking/gossipsub/test_heartbeat.py | 6 +- .../networking/test_network_service.py | 4 +- .../transport/identity/test_keypair.py | 15 +- .../transport/identity/test_signature.py | 45 ++--- .../networking/transport/test_peer_id.py | 13 +- .../subspecs/ssz/test_signed_attestation.py | 2 +- tests/lean_spec/subspecs/sync/test_service.py | 2 +- .../subspecs/validator/test_service.py | 19 +- 77 files changed, 685 insertions(+), 868 deletions(-) create mode 100644 .claude/ralph-loop.local.md diff --git a/.claude/ralph-loop.local.md b/.claude/ralph-loop.local.md new file mode 100644 index 00000000..305fe8f0 --- /dev/null +++ b/.claude/ralph-loop.local.md @@ -0,0 +1,9 @@ +--- +active: true +iteration: 137 +max_iterations: 0 +completion_promise: "DONE" +started_at: "2026-02-20T14:14:04Z" +--- + +Can you go with the py architect to sanity check the codebase everywhere to check that the code adheres to all the most modern Python principles. It must be extremely lean, clean, and compact so that it truly serves as a minimal running specification/client that is considered a reference worldwide. Therefore, take inspiration from the best repositories in the world, or even from the Python compiler itself, and create something perfect. Of course, as a specification, it's crucial that things are organized correctly and clearly, so we'll prioritize, for example, storing functions in objects rather than isolated functions, while avoiding excessive abstraction to prevent overly complexifying the codebase. diff --git a/CLAUDE.md b/CLAUDE.md index ace6f553..456213b7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -103,6 +103,30 @@ from cryptography.hazmat.primitives.asymmetric import x25519 This keeps code readable and avoids mental overhead of tracking renamed imports. +### Type Annotations + +**Never quote type annotations when `from __future__ import annotations` is present.** With future annotations, all annotations are already lazy strings. Adding quotes is redundant and noisy. + +Bad: +```python +from __future__ import annotations + +def create(cls) -> "Store": + ... +``` + +Good: +```python +from __future__ import annotations + +def create(cls) -> Store: + ... +``` + +The only valid use of quoted annotations is in files that do NOT have `from __future__ import annotations` and need a forward reference. Prefer adding the future import instead. + +**Prefer narrow domain types over raw builtins.** Use `Bytes32`, `Bytes33`, `Bytes52` etc. instead of `bytes` in signatures. Spec code should never accept or return `bytes` when a more specific type exists. + ### Module-Level Constants Use docstrings (not comments) to document module-level constants. Place the docstring immediately after the assignment. diff --git a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py index 4f67712b..30693e24 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py +++ b/packages/testing/src/consensus_testing/test_fixtures/fork_choice.py @@ -157,7 +157,7 @@ def set_max_slot_default(self) -> Self: if isinstance(step, BlockStep): max_slot_value = max(max_slot_value, step.block.slot) elif isinstance(step, AttestationStep): - max_slot_value = max(max_slot_value, step.attestation.message.slot) + max_slot_value = max(max_slot_value, step.attestation.data.slot) self.max_slot = max_slot_value @@ -219,9 +219,8 @@ def make_fixture(self) -> Self: # # The Store is the node's local view of the chain. # It starts from a trusted anchor (usually genesis). - store = Store.get_forkchoice_store( - anchor_state=self.anchor_state, - anchor_block=self.anchor_block, + store = self.anchor_state.to_forkchoice_store( + self.anchor_block, validator_id=DEFAULT_VALIDATOR_ID, ) @@ -400,7 +399,7 @@ def _build_block_from_spec( working_store = working_store.on_gossip_attestation( SignedAttestation( validator_id=attestation.validator_id, - message=attestation.data, + data=attestation.data, signature=signature, ), scheme=LEAN_ENV_TO_SCHEMES[self.lean_env], diff --git a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py index 0bbc614b..32331602 100644 --- a/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py +++ b/packages/testing/src/consensus_testing/test_fixtures/verify_signatures.py @@ -23,7 +23,7 @@ ) from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.state.state import State -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.ssz import hash_tree_root from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof from lean_spec.types import Bytes32 @@ -42,7 +42,7 @@ def _create_dummy_aggregated_proof(validator_ids: list[ValidatorIndex]) -> Aggre so it will fail verification. """ return AggregatedSignatureProof( - participants=AggregationBits.from_validator_indices(validator_ids), + participants=AggregationBits.from_validator_indices(ValidatorIndices(data=validator_ids)), proof_data=ByteListMiB(data=b"\x00" * 32), # Invalid proof bytes ) @@ -216,7 +216,9 @@ def _build_block_from_spec( data_root = attestation_data.data_root_bytes() # Create aggregated attestation claiming validator_ids as participants - aggregation_bits = AggregationBits.from_validator_indices(invalid_spec.validator_ids) + aggregation_bits = AggregationBits.from_validator_indices( + ValidatorIndices(data=invalid_spec.validator_ids) + ) invalid_aggregated = AggregatedAttestation( aggregation_bits=aggregation_bits, data=attestation_data, @@ -238,7 +240,9 @@ def _build_block_from_spec( ] # Create valid aggregated proof from actual signers valid_proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(invalid_spec.signer_ids), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=invalid_spec.signer_ids) + ), public_keys=signer_public_keys, signatures=signer_signatures, message=data_root, diff --git a/src/lean_spec/__main__.py b/src/lean_spec/__main__.py index b98a6ca2..fc666070 100644 --- a/src/lean_spec/__main__.py +++ b/src/lean_spec/__main__.py @@ -36,7 +36,6 @@ from lean_spec.subspecs.containers.block.types import AggregatedAttestations from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.validator import SubnetId -from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.genesis import GenesisConfig from lean_spec.subspecs.networking.client import LiveNetworkEventSource from lean_spec.subspecs.networking.enr import ENR @@ -271,7 +270,7 @@ async def _init_from_checkpoint( # The store treats this as the new "genesis" for fork choice purposes. # All blocks before the checkpoint are effectively pruned. validator_id = validator_registry.primary_index() if validator_registry else None - store = Store.get_forkchoice_store(state, anchor_block, validator_id) + store = state.to_forkchoice_store(anchor_block, validator_id) logger.info( "Initialized from checkpoint at slot %d (finalized=%s)", state.slot, diff --git a/src/lean_spec/subspecs/chain/clock.py b/src/lean_spec/subspecs/chain/clock.py index 0ef5502c..e8a2f44a 100644 --- a/src/lean_spec/subspecs/chain/clock.py +++ b/src/lean_spec/subspecs/chain/clock.py @@ -9,9 +9,9 @@ """ import asyncio +from collections.abc import Callable from dataclasses import dataclass from time import time as wall_time -from typing import Callable from lean_spec.subspecs.containers.slot import Slot from lean_spec.types import Uint64 diff --git a/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py b/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py index 3b12ea31..097bb527 100644 --- a/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py +++ b/src/lean_spec/subspecs/containers/attestation/aggregation_bits.py @@ -19,15 +19,12 @@ class AggregationBits(BaseBitlist): LIMIT = int(VALIDATOR_REGISTRY_LIMIT) @classmethod - def from_validator_indices( - cls, indices: "ValidatorIndices | list[ValidatorIndex]" - ) -> AggregationBits: + def from_validator_indices(cls, indices: ValidatorIndices) -> AggregationBits: """ Construct aggregation bits from a set of validator indices. Args: - indices: Validator indices to set in the bitlist. Accepts either - a ValidatorIndices collection or a plain list of ValidatorIndex. + indices: Validator indices to set in the bitlist. Returns: AggregationBits with the corresponding indices set to True. @@ -36,8 +33,7 @@ def from_validator_indices( AssertionError: If no indices are provided. AssertionError: If any index is outside the supported LIMIT. """ - # Extract list from ValidatorIndices if needed - index_list = indices.data if isinstance(indices, ValidatorIndices) else indices + index_list = indices.data # Require at least one validator for a valid aggregation. if not index_list: @@ -57,7 +53,7 @@ def from_validator_indices( # - False elsewhere. return cls(data=[Boolean(i in ids) for i in range(max_id + 1)]) - def to_validator_indices(self) -> "ValidatorIndices": + def to_validator_indices(self) -> ValidatorIndices: """ Extract all validator indices encoded in these aggregation bits. diff --git a/src/lean_spec/subspecs/containers/attestation/attestation.py b/src/lean_spec/subspecs/containers/attestation/attestation.py index 092e355b..89efa4ad 100644 --- a/src/lean_spec/subspecs/containers/attestation/attestation.py +++ b/src/lean_spec/subspecs/containers/attestation/attestation.py @@ -16,7 +16,7 @@ from collections import defaultdict from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.ssz import hash_tree_root from lean_spec.types import Bytes32, Container @@ -56,15 +56,9 @@ class Attestation(Container): """The attestation data produced by the validator.""" -class SignedAttestation(Container): +class SignedAttestation(Attestation): """Validator attestation bundled with its signature.""" - validator_id: ValidatorIndex - """The index of the validator making the attestation.""" - - message: AttestationData - """The attestation message signed by the validator.""" - signature: Signature """Signature aggregation produced by the leanVM (SNARKs in the future).""" @@ -103,7 +97,9 @@ def aggregate_by_data( return [ cls( - aggregation_bits=AggregationBits.from_validator_indices(validator_ids), + aggregation_bits=AggregationBits.from_validator_indices( + ValidatorIndices(data=validator_ids) + ), data=data, ) for data, validator_ids in data_to_validator_ids.items() diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index 57d6e2bb..4b54b08d 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -2,8 +2,10 @@ from __future__ import annotations -from typing import AbstractSet, Collection, Iterable +from typing import TYPE_CHECKING, AbstractSet, Collection, Iterable +from lean_spec.subspecs.chain.clock import Interval +from lean_spec.subspecs.chain.config import INTERVALS_PER_SLOT from lean_spec.subspecs.ssz.hash import hash_tree_root from lean_spec.subspecs.xmss.aggregation import ( AggregatedSignatureProof, @@ -24,7 +26,7 @@ from ..checkpoint import Checkpoint from ..config import Config from ..slot import Slot -from ..validator import ValidatorIndex +from ..validator import ValidatorIndex, ValidatorIndices from .types import ( HistoricalBlockHashes, JustificationRoots, @@ -33,6 +35,9 @@ Validators, ) +if TYPE_CHECKING: + from lean_spec.subspecs.forkchoice import Store + class State(Container): """The main consensus state object.""" @@ -73,7 +78,7 @@ class State(Container): """A bitlist of validators who participated in justifications.""" @classmethod - def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "State": + def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> State: """ Generate a genesis state with empty history and proper initial values. @@ -117,7 +122,72 @@ def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "Stat justifications_validators=JustificationValidators(data=[]), ) - def process_slots(self, target_slot: Slot) -> "State": + def to_forkchoice_store( + self, + anchor_block: Block, + validator_id: ValidatorIndex | None, + ) -> Store: + """ + Initialize a forkchoice store from this state and an anchor block. + + The anchor block and this state form the starting point for fork choice. + Both are treated as justified and finalized. + + Args: + anchor_block: A trusted block (e.g. genesis or checkpoint). + validator_id: Index of the validator running this store. + + Returns: + A new Store instance, ready to accept blocks and attestations. + + Raises: + AssertionError: + If the anchor block's state root does not match the hash + of this state. + """ + from lean_spec.subspecs.forkchoice import Store + + # Compute the SSZ root of this state. + # + # This is the canonical hash that should appear in the block's state root. + computed_state_root = hash_tree_root(self) + + # Check that the block actually points to this state. + # + # If this fails, the caller has supplied inconsistent inputs. + assert anchor_block.state_root == computed_state_root, ( + "Anchor block state root must match anchor state hash" + ) + + # Compute the SSZ root of the anchor block itself. + # + # This root will be used as: + # - the key in the blocks/states maps, + # - the initial head, + # - the root of the initial checkpoints. + anchor_root = hash_tree_root(anchor_block) + + # Read the slot at which the anchor block was proposed. + anchor_slot = anchor_block.slot + + # Initialize checkpoints from this state. + # + # We explicitly set the root to the anchor block root. + # The state internally might have zero-hash checkpoints (if genesis), + # but the Store must treat the anchor block as the justified/finalized point. + return Store( + time=Interval(anchor_slot * INTERVALS_PER_SLOT), + config=self.config, + head=anchor_root, + safe_target=anchor_root, + latest_justified=self.latest_justified.model_copy(update={"root": anchor_root}), + latest_finalized=self.latest_finalized.model_copy(update={"root": anchor_root}), + blocks={anchor_root: anchor_block}, + states={anchor_root: self}, + validator_id=validator_id, + ) + + def process_slots(self, target_slot: Slot) -> State: """ Advance the state through empty slots up to, but not including, target_slot. @@ -192,7 +262,7 @@ def process_slots(self, target_slot: Slot) -> "State": # Reached the target slot. Return the advanced state. return state - def process_block_header(self, block: Block) -> "State": + def process_block_header(self, block: Block) -> State: """ Validate the block header and update header-linked state. @@ -341,7 +411,7 @@ def process_block_header(self, block: Block) -> "State": } ) - def process_block(self, block: Block) -> "State": + def process_block(self, block: Block) -> State: """ Apply full block processing including header and body. @@ -368,7 +438,7 @@ def process_block(self, block: Block) -> "State": def process_attestations( self, attestations: Iterable[AggregatedAttestation], - ) -> "State": + ) -> State: """ Apply attestations and update justification/finalization according to the Lean Consensus 3SF-mini rules. @@ -617,7 +687,7 @@ def process_attestations( } ) - def state_transition(self, block: Block, valid_signatures: bool = True) -> "State": + def state_transition(self, block: Block, valid_signatures: bool = True) -> State: """ Apply the complete state transition function for a block. @@ -670,7 +740,7 @@ def build_block( available_attestations: Iterable[Attestation] | None = None, known_block_roots: AbstractSet[Bytes32] | None = None, aggregated_payloads: dict[SignatureKey, list[AggregatedSignatureProof]] | None = None, - ) -> tuple[Block, "State", list[AggregatedAttestation], list[AggregatedSignatureProof]]: + ) -> tuple[Block, State, list[AggregatedAttestation], list[AggregatedSignatureProof]]: """ Build a valid block on top of this state. @@ -858,7 +928,9 @@ def aggregate_gossip_signatures( # The aggregation combines multiple XMSS signatures into a single # compact proof that can verify all participants signed the message. if gossip_ids: - participants = AggregationBits.from_validator_indices(gossip_ids) + participants = AggregationBits.from_validator_indices( + ValidatorIndices(data=gossip_ids) + ) proof = AggregatedSignatureProof.aggregate( participants=participants, public_keys=gossip_keys, diff --git a/src/lean_spec/subspecs/containers/state/types.py b/src/lean_spec/subspecs/containers/state/types.py index db53bbee..631e6648 100644 --- a/src/lean_spec/subspecs/containers/state/types.py +++ b/src/lean_spec/subspecs/containers/state/types.py @@ -70,7 +70,7 @@ def with_justified( finalized_slot: Slot, target_slot: Slot, value: Boolean, - ) -> "JustifiedSlots": + ) -> JustifiedSlots: """ Return a new bitfield with the justification status updated. diff --git a/src/lean_spec/subspecs/forkchoice/store.py b/src/lean_spec/subspecs/forkchoice/store.py index 827b93db..58b424c6 100644 --- a/src/lean_spec/subspecs/forkchoice/store.py +++ b/src/lean_spec/subspecs/forkchoice/store.py @@ -154,76 +154,6 @@ class Store(Container): Used for recursive signature aggregation when building blocks. """ - @classmethod - def get_forkchoice_store( - cls, - anchor_state: State, - anchor_block: Block, - validator_id: ValidatorIndex | None, - ) -> "Store": - """ - Initialize forkchoice store from an anchor state and block. - - The anchor block and its state form the starting point for fork choice. - We treat this anchor as both justified and finalized. - - Args: - anchor_state: The state corresponding to the anchor block. - anchor_block: A trusted block (e.g. genesis or checkpoint). - validator_id: Index of the validator running this store. - - Returns: - A new Store instance, ready to accept blocks and attestations. - - Raises: - AssertionError: - If the anchor block's state root does not match the hash - of the provided state. - """ - # Compute the SSZ root of the given state. - # - # This is the canonical hash that should appear in the block's state root. - computed_state_root = hash_tree_root(anchor_state) - - # Check that the block actually points to this state. - # - # If this fails, the caller has supplied inconsistent inputs. - assert anchor_block.state_root == computed_state_root, ( - "Anchor block state root must match anchor state hash" - ) - - # Compute the SSZ root of the anchor block itself. - # - # This root will be used as: - # - the key in the blocks/states maps, - # - the initial head, - # - the root of the initial checkpoints. - anchor_root = hash_tree_root(anchor_block) - - # Read the slot at which the anchor block was proposed. - anchor_slot = anchor_block.slot - - # Build an initial checkpoint using the anchor block. - # - # Both the root and the slot come directly from the anchor. - # Initialize checkpoints from the anchor state - # - # We explicitly set the root to the anchor block root. - # The anchor state internally might have zero-hash checkpoints (if genesis), - # but the Store must treat the anchor block as the justified/finalized point. - - return cls( - time=Interval(anchor_slot * INTERVALS_PER_SLOT), - config=anchor_state.config, - head=anchor_root, - safe_target=anchor_root, - latest_justified=anchor_state.latest_justified.model_copy(update={"root": anchor_root}), - latest_finalized=anchor_state.latest_finalized.model_copy(update={"root": anchor_root}), - blocks={anchor_root: anchor_block}, - states={anchor_root: anchor_state}, - validator_id=validator_id, - ) - def prune_stale_attestation_data(self) -> "Store": """ Remove attestation data that can no longer influence fork choice. @@ -376,7 +306,7 @@ def on_gossip_attestation( AssertionError: If signature verification fails. """ validator_id = signed_attestation.validator_id - attestation_data = signed_attestation.message + attestation_data = signed_attestation.data signature = signed_attestation.signature # Validate the attestation first so unknown blocks are rejected cleanly @@ -651,8 +581,7 @@ def on_block( # Store proposer signature for future lookup if it belongs to the same committee # as the current validator. if self.validator_id is not None: - proposer_validator_id = proposer_attestation.validator_id - proposer_subnet_id = proposer_validator_id.compute_subnet_id( + proposer_subnet_id = proposer_attestation.validator_id.compute_subnet_id( ATTESTATION_COMMITTEE_COUNT ) current_validator_subnet_id = self.validator_id.compute_subnet_id( @@ -1010,7 +939,7 @@ def aggregate_committee_signatures(self) -> tuple["Store", list[SignedAggregated # Each SignatureKey contains (validator_id, data_root) # We look up the full AttestationData from attestation_data_by_root attestation_list: list[Attestation] = [] - for sig_key in self.gossip_signatures.keys(): + for sig_key in self.gossip_signatures: data_root = sig_key.data_root attestation_data = self.attestation_data_by_root.get(data_root) if attestation_data is not None: diff --git a/src/lean_spec/subspecs/networking/config.py b/src/lean_spec/subspecs/networking/config.py index 51e0cf36..8750dc3e 100644 --- a/src/lean_spec/subspecs/networking/config.py +++ b/src/lean_spec/subspecs/networking/config.py @@ -2,8 +2,6 @@ from typing import Final -from lean_spec.types.byte_arrays import Bytes1 - from .types import DomainType # --- Request/Response Limits --- @@ -32,13 +30,13 @@ # --- Gossip Message Domains --- -MESSAGE_DOMAIN_INVALID_SNAPPY: Final[DomainType] = Bytes1(b"\x00") +MESSAGE_DOMAIN_INVALID_SNAPPY: Final[DomainType] = DomainType(b"\x00") """1-byte domain for gossip message-id isolation of invalid snappy messages. Per Ethereum spec, prepended to the message hash when decompression fails. """ -MESSAGE_DOMAIN_VALID_SNAPPY: Final[DomainType] = Bytes1(b"\x01") +MESSAGE_DOMAIN_VALID_SNAPPY: Final[DomainType] = DomainType(b"\x01") """1-byte domain for gossip message-id isolation of valid snappy messages. Per Ethereum spec, prepended to the message hash when decompression succeeds. diff --git a/src/lean_spec/subspecs/networking/discovery/crypto.py b/src/lean_spec/subspecs/networking/discovery/crypto.py index 62521446..e9fcf29a 100644 --- a/src/lean_spec/subspecs/networking/discovery/crypto.py +++ b/src/lean_spec/subspecs/networking/discovery/crypto.py @@ -155,11 +155,6 @@ def aes_ctr_encrypt(key: Bytes16, iv: Bytes16, plaintext: bytes) -> bytes: Returns: Ciphertext of same length as plaintext. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(iv) != CTR_IV_SIZE: - raise ValueError(f"IV must be {CTR_IV_SIZE} bytes, got {len(iv)}") - cipher = Cipher(algorithms.AES(key), modes.CTR(iv)) encryptor = cipher.encryptor() return encryptor.update(plaintext) + encryptor.finalize() @@ -198,11 +193,6 @@ def aes_gcm_encrypt(key: Bytes16, nonce: Bytes12, plaintext: bytes, aad: bytes) Returns: Ciphertext with 16-byte authentication tag appended. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - aesgcm = AESGCM(key) return aesgcm.encrypt(nonce, plaintext, aad) @@ -225,16 +215,11 @@ def aes_gcm_decrypt(key: Bytes16, nonce: Bytes12, ciphertext: bytes, aad: bytes) Raises: cryptography.exceptions.InvalidTag: If authentication fails. """ - if len(key) != AES_KEY_SIZE: - raise ValueError(f"Key must be {AES_KEY_SIZE} bytes, got {len(key)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - aesgcm = AESGCM(key) return aesgcm.decrypt(nonce, ciphertext, aad) -def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: bytes) -> Bytes33: +def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: Bytes33 | Bytes65) -> Bytes33: """ Perform secp256k1 ECDH key agreement. @@ -252,9 +237,6 @@ def ecdh_agree(private_key_bytes: Bytes32, public_key_bytes: bytes) -> Bytes33: Returns: 33-byte shared secret (compressed point from ECDH). """ - if len(private_key_bytes) != 32: - raise ValueError(f"Private key must be 32 bytes, got {len(private_key_bytes)}") - scalar = int.from_bytes(private_key_bytes, "big") point = _decompress_pubkey(public_key_bytes) result = _point_mul(scalar, point) @@ -287,7 +269,7 @@ def generate_secp256k1_keypair() -> tuple[Bytes32, Bytes33]: return Bytes32(private_bytes), Bytes33(public_bytes) -def pubkey_to_uncompressed(public_key_bytes: bytes) -> Bytes65: +def pubkey_to_uncompressed(public_key_bytes: Bytes33 | Bytes65) -> Bytes65: """ Convert any secp256k1 public key to uncompressed format. @@ -337,11 +319,6 @@ def sign_id_nonce( Returns: 64-byte signature (r || s, each 32 bytes). """ - if len(private_key_bytes) != 32: - raise ValueError(f"Private key must be 32 bytes, got {len(private_key_bytes)}") - if len(dest_node_id) != 32: - raise ValueError(f"Dest node ID must be 32 bytes, got {len(dest_node_id)}") - # The signing input binds several values together per the spec: # # - Domain separator prevents cross-protocol signature reuse @@ -404,11 +381,6 @@ def verify_id_nonce_signature( Returns: True if signature is valid, False otherwise. """ - if len(signature) != ID_SIGNATURE_SIZE: - return False - if len(dest_node_id) != 32: - return False - # Build the signing input per spec: # domain-separator || challenge-data || ephemeral-pubkey || node-id-B input_data = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 2848ac9e..146e15ab 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -32,7 +32,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Bytes64 +from lean_spec.types import Bytes16, Bytes32, Bytes33 from .config import HANDSHAKE_TIMEOUT_SECS from .crypto import ( @@ -41,7 +41,7 @@ verify_id_nonce_signature, ) from .keys import derive_keys_from_pubkey -from .messages import PacketFlag, Port +from .messages import IdNonce, Nonce, PacketFlag, Port from .packet import ( HandshakeAuthdata, WhoAreYouAuthdata, @@ -88,16 +88,16 @@ class PendingHandshake: remote_node_id: NodeId """32-byte node ID of the remote peer.""" - id_nonce: bytes | None = None + id_nonce: IdNonce | None = None """16-byte challenge nonce (set when WHOAREYOU sent/received).""" challenge_data: bytes | None = None """Full WHOAREYOU packet data for key derivation (masking-iv || static-header || authdata).""" - ephemeral_privkey: bytes | None = None + ephemeral_privkey: Bytes32 | None = None """32-byte ephemeral private key (set when we send HANDSHAKE).""" - challenge_nonce: bytes | None = None + challenge_nonce: Nonce | None = None """12-byte nonce from the packet that triggered WHOAREYOU.""" remote_enr_seq: SeqNumber = SeqNumber(0) @@ -111,7 +111,7 @@ def is_expired(self, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS) -> bool: return time.time() - self.started_at > timeout_secs -@dataclass +@dataclass(frozen=True, slots=True) class HandshakeResult: """Result of a completed handshake.""" @@ -145,18 +145,13 @@ class HandshakeManager: def __init__( self, local_node_id: NodeId, - local_private_key: bytes, + local_private_key: Bytes32, local_enr_rlp: bytes, local_enr_seq: SeqNumber, session_cache: SessionCache, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS, ): """Initialize handshake manager.""" - if len(local_node_id) != 32: - raise ValueError(f"Local node ID must be 32 bytes, got {len(local_node_id)}") - if len(local_private_key) != 32: - raise ValueError(f"Local private key must be 32 bytes, got {len(local_private_key)}") - self._local_node_id = local_node_id self._local_private_key = local_private_key self._local_enr_rlp = local_enr_rlp @@ -206,10 +201,10 @@ def start_handshake(self, remote_node_id: NodeId) -> PendingHandshake: def create_whoareyou( self, remote_node_id: NodeId, - request_nonce: bytes, + request_nonce: Nonce, remote_enr_seq: SeqNumber, - masking_iv: bytes, - ) -> tuple[bytes, bytes, bytes, bytes]: + masking_iv: Bytes16, + ) -> tuple[bytes, bytes, Nonce, bytes]: """ Create a WHOAREYOU packet in response to an undecryptable message. @@ -229,20 +224,20 @@ def create_whoareyou( - challenge_data: Full data for key derivation (masking-iv || static-header || authdata) """ id_nonce = generate_id_nonce() - authdata = encode_whoareyou_authdata(bytes(id_nonce), remote_enr_seq) + authdata = encode_whoareyou_authdata(id_nonce, remote_enr_seq) # Build challenge_data per spec: masking-iv || static-header || authdata. # # This data becomes the HKDF salt for session key derivation. # Both sides must use identical challenge_data to derive matching keys. static_header = encode_static_header(PacketFlag.WHOAREYOU, request_nonce, len(authdata)) - challenge_data = masking_iv + static_header + authdata + challenge_data = bytes(masking_iv) + static_header + authdata with self._lock: pending = PendingHandshake( state=HandshakeState.SENT_WHOAREYOU, remote_node_id=remote_node_id, - id_nonce=bytes(id_nonce), + id_nonce=id_nonce, challenge_data=challenge_data, challenge_nonce=request_nonce, remote_enr_seq=remote_enr_seq, @@ -255,11 +250,11 @@ def create_handshake_response( self, remote_node_id: NodeId, whoareyou: WhoAreYouAuthdata, - remote_pubkey: bytes, + remote_pubkey: Bytes33, challenge_data: bytes, remote_ip: str = "", remote_port: Port = _DEFAULT_PORT, - ) -> tuple[bytes, bytes, bytes]: + ) -> tuple[bytes, Bytes16, Bytes16]: """ Create a HANDSHAKE packet in response to WHOAREYOU. @@ -288,10 +283,10 @@ def create_handshake_response( # Per spec, the signature input includes the full challenge_data (not just id_nonce) # to bind the signature to this specific WHOAREYOU exchange. id_signature = sign_id_nonce( - Bytes32(self._local_private_key), + self._local_private_key, challenge_data, eph_pubkey, - Bytes32(remote_node_id), + remote_node_id, ) # Include our ENR if the remote's known seq is stale. @@ -314,8 +309,8 @@ def create_handshake_response( send_key, recv_key = derive_keys_from_pubkey( local_private_key=eph_privkey, remote_public_key=remote_pubkey, - local_node_id=Bytes32(self._local_node_id), - remote_node_id=Bytes32(remote_node_id), + local_node_id=self._local_node_id, + remote_node_id=remote_node_id, challenge_data=challenge_data, is_initiator=True, ) @@ -402,11 +397,11 @@ def handle_handshake( # The signature was computed over challenge_data (not just id_nonce), # and includes our node_id as the WHOAREYOU sender (node-id-B). if not verify_id_nonce_signature( - signature=Bytes64(handshake.id_signature), + signature=handshake.id_signature, challenge_data=challenge_data, - ephemeral_pubkey=Bytes33(handshake.eph_pubkey), - dest_node_id=Bytes32(self._local_node_id), - public_key_bytes=Bytes33(remote_pubkey), + ephemeral_pubkey=handshake.eph_pubkey, + dest_node_id=self._local_node_id, + public_key_bytes=remote_pubkey, ): raise HandshakeError("Invalid ID signature") @@ -415,10 +410,10 @@ def handle_handshake( # The challenge_data was saved when we sent WHOAREYOU. # Using the same data ensures both sides derive identical keys. send_key, recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(self._local_private_key), + local_private_key=self._local_private_key, remote_public_key=handshake.eph_pubkey, - local_node_id=Bytes32(self._local_node_id), - remote_node_id=Bytes32(remote_node_id), + local_node_id=self._local_node_id, + remote_node_id=remote_node_id, challenge_data=challenge_data, is_initiator=False, ) @@ -471,7 +466,7 @@ def cleanup_expired(self) -> int: del self._pending[node_id] return len(expired) - def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes | None: + def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> Bytes33 | None: """ Retrieve the remote node's static public key for signature verification. @@ -503,7 +498,7 @@ def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes # to ensure the ENR belongs to who we think sent it. computed_id = enr.compute_node_id() if computed_id is not None and bytes(computed_id) == node_id: - return bytes(enr.public_key) + return Bytes33(enr.public_key) except (ValueError, KeyError, IndexError): pass @@ -513,7 +508,7 @@ def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes # Use it if the handshake packet did not include an ENR. cached_enr = self._enr_cache.get(node_id) if cached_enr is not None and cached_enr.public_key is not None: - return bytes(cached_enr.public_key) + return Bytes33(cached_enr.public_key) return None diff --git a/src/lean_spec/subspecs/networking/discovery/keys.py b/src/lean_spec/subspecs/networking/discovery/keys.py index c664c9cb..dc974823 100644 --- a/src/lean_spec/subspecs/networking/discovery/keys.py +++ b/src/lean_spec/subspecs/networking/discovery/keys.py @@ -30,7 +30,7 @@ from Crypto.Hash import keccak -from lean_spec.types import Bytes16, Bytes32, Bytes33 +from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes65 from .crypto import ecdh_agree, pubkey_to_uncompressed @@ -75,13 +75,6 @@ def derive_keys( The initiator uses initiator_key to encrypt and recipient_key to decrypt. The recipient uses recipient_key to encrypt and initiator_key to decrypt. """ - if len(secret) != 33: - raise ValueError(f"Secret must be 33 bytes, got {len(secret)}") - if len(initiator_id) != 32: - raise ValueError(f"Initiator ID must be 32 bytes, got {len(initiator_id)}") - if len(recipient_id) != 32: - raise ValueError(f"Recipient ID must be 32 bytes, got {len(recipient_id)}") - # HKDF-Extract: PRK = HMAC-SHA256(salt, IKM). # # Using challenge_data as salt binds session keys to the specific WHOAREYOU. @@ -111,7 +104,7 @@ def derive_keys( def derive_keys_from_pubkey( local_private_key: Bytes32, - remote_public_key: bytes, + remote_public_key: Bytes33 | Bytes65, local_node_id: Bytes32, remote_node_id: Bytes32, challenge_data: bytes, @@ -125,7 +118,7 @@ def derive_keys_from_pubkey( Args: local_private_key: Our 32-byte secp256k1 private key. - remote_public_key: Peer's compressed public key. + remote_public_key: Peer's compressed (33-byte) or uncompressed (65-byte) public key. local_node_id: Our 32-byte node ID. remote_node_id: Peer's 32-byte node ID. challenge_data: WHOAREYOU packet data (masking-iv || static-header || authdata). @@ -154,7 +147,7 @@ def derive_keys_from_pubkey( return recipient_key, initiator_key -def compute_node_id(public_key_bytes: bytes) -> Bytes32: +def compute_node_id(public_key_bytes: Bytes33 | Bytes65) -> Bytes32: """ Compute node ID from public key. diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py index a34cdd7b..b7feeda3 100644 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ b/src/lean_spec/subspecs/networking/discovery/messages.py @@ -77,11 +77,12 @@ class Nonce(BaseBytes): LENGTH: ClassVar[int] = 12 -Distance = Uint16 -"""Log2 distance (0-256). Distance 0 returns the node's own ENR.""" +class Distance(Uint16): + """Log2 distance (0-256). Distance 0 returns the node's own ENR.""" -Port = Uint16 -"""UDP port number (0-65535).""" + +class Port(Uint16): + """UDP port number (0-65535).""" class PacketFlag(IntEnum): diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index 45896c7e..18b73dab 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -34,15 +34,13 @@ from dataclasses import dataclass from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16 +from lean_spec.types import Bytes12, Bytes16, Bytes33, Bytes64 from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE from .crypto import ( AES_KEY_SIZE, - COMPRESSED_PUBKEY_SIZE, CTR_IV_SIZE, GCM_NONCE_SIZE, - ID_SIGNATURE_SIZE, aes_ctr_decrypt, aes_ctr_encrypt, aes_gcm_decrypt, @@ -109,10 +107,10 @@ class HandshakeAuthdata: eph_key_size: int """Size of ephemeral public key. 33 for compressed secp256k1.""" - id_signature: bytes + id_signature: Bytes64 """ID nonce signature proving identity ownership.""" - eph_pubkey: bytes + eph_pubkey: Bytes33 """Ephemeral public key for ECDH.""" record: bytes | None @@ -122,10 +120,10 @@ class HandshakeAuthdata: def encode_packet( dest_node_id: NodeId, flag: PacketFlag, - nonce: bytes, + nonce: Nonce, authdata: bytes, message: bytes, - encryption_key: bytes | None = None, + encryption_key: Bytes16 | None = None, masking_iv: Bytes16 | None = None, ) -> bytes: """ @@ -144,11 +142,6 @@ def encode_packet( Returns: Complete encoded packet ready for UDP transmission. """ - if len(dest_node_id) != 32: - raise ValueError(f"Destination node ID must be 32 bytes, got {len(dest_node_id)}") - if len(nonce) != GCM_NONCE_SIZE: - raise ValueError(f"Nonce must be {GCM_NONCE_SIZE} bytes, got {len(nonce)}") - if masking_iv is None: # Fresh random IV for header masking. # @@ -181,9 +174,7 @@ def encode_packet( # The AAD binds the plaintext header to the encrypted message. # The recipient reconstructs this from the decoded header. message_ad = bytes(masking_iv) + header - encrypted_message = aes_gcm_encrypt( - Bytes16(encryption_key), Bytes12(nonce), message, message_ad - ) + encrypted_message = aes_gcm_encrypt(encryption_key, Bytes12(nonce), message, message_ad) # Assemble packet. packet = bytes(masking_iv) + masked_header + encrypted_message @@ -294,10 +285,10 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: raise ValueError(f"Handshake authdata truncated: {len(authdata)} < {expected_min}") offset = HANDSHAKE_HEADER_SIZE - id_signature = authdata[offset : offset + sig_size] + id_signature = Bytes64(authdata[offset : offset + sig_size]) offset += sig_size - eph_pubkey = authdata[offset : offset + eph_key_size] + eph_pubkey = Bytes33(authdata[offset : offset + eph_key_size]) offset += eph_key_size # Remaining bytes are the RLP-encoded ENR, included when the recipient's @@ -315,8 +306,8 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: def decrypt_message( - encryption_key: bytes, - nonce: bytes, + encryption_key: Bytes16, + nonce: Nonce, ciphertext: bytes, message_ad: bytes, ) -> bytes: @@ -332,27 +323,23 @@ def decrypt_message( Returns: Decrypted message plaintext. """ - return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, message_ad) + return aes_gcm_decrypt(encryption_key, Bytes12(nonce), ciphertext, message_ad) def encode_message_authdata(src_id: NodeId) -> bytes: """Encode MESSAGE packet authdata.""" - if len(src_id) != 32: - raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}") return src_id -def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: SeqNumber) -> bytes: +def encode_whoareyou_authdata(id_nonce: IdNonce, enr_seq: SeqNumber) -> bytes: """Encode WHOAREYOU packet authdata.""" - if len(id_nonce) != 16: - raise ValueError(f"ID nonce must be 16 bytes, got {len(id_nonce)}") return id_nonce + struct.pack(">Q", enr_seq) def encode_handshake_authdata( src_id: NodeId, - id_signature: bytes, - eph_pubkey: bytes, + id_signature: Bytes64, + eph_pubkey: Bytes33, record: bytes | None = None, ) -> bytes: """ @@ -367,15 +354,6 @@ def encode_handshake_authdata( Returns: Encoded authdata bytes. """ - if len(src_id) != 32: - raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}") - if len(id_signature) != ID_SIGNATURE_SIZE: - raise ValueError(f"Signature must be {ID_SIGNATURE_SIZE} bytes, got {len(id_signature)}") - if len(eph_pubkey) != COMPRESSED_PUBKEY_SIZE: - raise ValueError( - f"Ephemeral pubkey must be {COMPRESSED_PUBKEY_SIZE} bytes, got {len(eph_pubkey)}" - ) - authdata = src_id + bytes([len(id_signature), len(eph_pubkey)]) + id_signature + eph_pubkey if record is not None: @@ -394,7 +372,7 @@ def generate_id_nonce() -> IdNonce: return IdNonce(os.urandom(16)) -def encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes: +def encode_static_header(flag: PacketFlag, nonce: Nonce, authdata_size: int) -> bytes: """Encode the 23-byte static header.""" return ( PROTOCOL_ID diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py index c3a0280f..8e17d3f5 100644 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ b/src/lean_spec/subspecs/networking/discovery/service.py @@ -33,6 +33,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.types import Bytes32, Bytes33 from lean_spec.types.uint import Uint8 from .codec import DiscoveryMessage @@ -90,7 +91,7 @@ class DiscoveryService: def __init__( self, local_enr: ENR, - private_key: bytes, + private_key: Bytes32, config: DiscoveryConfig | None = None, bootnodes: list[ENR] | None = None, ): @@ -103,7 +104,7 @@ def __init__( # Compute our node ID from public key. if local_enr.public_key is None: raise ValueError("Local ENR must have a public key") - self._local_node_id = NodeId(compute_node_id(bytes(local_enr.public_key))) + self._local_node_id = NodeId(compute_node_id(Bytes33(local_enr.public_key))) # Initialize routing table. self._routing_table = RoutingTable(local_id=NodeId(self._local_node_id)) diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 6708d556..f9fdffd9 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -25,6 +25,7 @@ from threading import Lock from lean_spec.subspecs.networking.types import NodeId +from lean_spec.types import Bytes16 from .config import BOND_EXPIRY_SECS from .messages import Port @@ -51,10 +52,10 @@ class Session: node_id: NodeId """Peer's 32-byte node ID.""" - send_key: bytes + send_key: Bytes16 """16-byte key for encrypting messages to this peer.""" - recv_key: bytes + recv_key: Bytes16 """16-byte key for decrypting messages from this peer.""" created_at: float @@ -134,8 +135,8 @@ def get(self, node_id: NodeId, ip: str = "", port: Port = _DEFAULT_PORT) -> Sess def create( self, node_id: NodeId, - send_key: bytes, - recv_key: bytes, + send_key: Bytes16, + recv_key: Bytes16, is_initiator: bool, ip: str = "", port: Port = _DEFAULT_PORT, @@ -157,13 +158,6 @@ def create( Returns: The newly created session. """ - if len(node_id) != 32: - raise ValueError(f"Node ID must be 32 bytes, got {len(node_id)}") - if len(send_key) != 16: - raise ValueError(f"Send key must be 16 bytes, got {len(send_key)}") - if len(recv_key) != 16: - raise ValueError(f"Recv key must be 16 bytes, got {len(recv_key)}") - key: SessionKey = (node_id, ip, port) now = time.time() session = Session( @@ -257,7 +251,7 @@ def _evict_oldest(self) -> None: del self.sessions[oldest_key] -@dataclass +@dataclass(slots=True) class BondCache: """ Cache tracking which nodes we have successfully bonded with. diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index b9fe97b2..b9d9f449 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -25,7 +25,7 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16 +from lean_spec.types import Bytes16, Bytes32, Bytes33 from .codec import ( DiscoveryMessage, @@ -78,7 +78,7 @@ class PendingRequest: sent_at: float """Timestamp when request was sent.""" - nonce: bytes + nonce: Nonce """Packet nonce (needed for WHOAREYOU handling).""" message: DiscoveryMessage @@ -105,7 +105,7 @@ class PendingMultiRequest: sent_at: float """Timestamp when request was sent.""" - nonce: bytes + nonce: Nonce """Packet nonce (needed for WHOAREYOU handling).""" message: DiscoveryMessage @@ -129,7 +129,7 @@ class DiscoveryProtocol(asyncio.DatagramProtocol): transport_handler: Parent transport for packet handling. """ - def __init__(self, transport_handler: DiscoveryTransport): + def __init__(self, transport_handler: DiscoveryTransport) -> None: """Initialize protocol handler.""" self._handler = transport_handler self._transport: asyncio.DatagramTransport | None = None @@ -172,7 +172,7 @@ class DiscoveryTransport: def __init__( self, local_node_id: NodeId, - local_private_key: bytes, + local_private_key: Bytes32, local_enr: ENR, config: DiscoveryConfig | None = None, ): @@ -390,7 +390,7 @@ async def _send_multi_response_request( request_id=request_id_bytes, dest_node_id=dest_node_id, sent_at=loop.time(), - nonce=bytes(nonce), + nonce=nonce, message=message, response_queue=response_queue, expected_total=None, @@ -498,7 +498,7 @@ async def _send_request( request_id=request_id_bytes, dest_node_id=dest_node_id, sent_at=loop.time(), - nonce=bytes(nonce), + nonce=nonce, message=message, future=future, ) @@ -547,7 +547,7 @@ def _build_message_packet( return encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=session.send_key, @@ -565,11 +565,11 @@ def _build_message_packet( # This approach avoids the need for session negotiation # before sending the first message. self._handshake_manager.start_handshake(dest_node_id) - dummy_key = os.urandom(16) + dummy_key = Bytes16(os.urandom(16)) return encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=dummy_key, @@ -626,7 +626,7 @@ async def _handle_whoareyou( # This links the challenge to the specific request that failed. pending = None for p in self._pending_requests.values(): - if p.nonce == bytes(header.nonce): + if p.nonce == header.nonce: pending = p break @@ -647,7 +647,7 @@ async def _handle_whoareyou( # We use the unmasked header, which we can reconstruct from the decoded values. masking_iv = raw_packet[:16] static_header = encode_static_header( - PacketFlag.WHOAREYOU, bytes(header.nonce), len(header.authdata) + PacketFlag.WHOAREYOU, header.nonce, len(header.authdata) ) challenge_data = masking_iv + static_header + header.authdata @@ -660,7 +660,7 @@ async def _handle_whoareyou( logger.debug("No ENR for %s, cannot complete handshake", remote_node_id.hex()[:16]) return - remote_pubkey = bytes(remote_enr.public_key) + remote_pubkey = Bytes33(remote_enr.public_key) # Build and send the HANDSHAKE response. try: @@ -685,7 +685,7 @@ async def _handle_whoareyou( packet = encode_packet( dest_node_id=remote_node_id, flag=PacketFlag.HANDSHAKE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=send_key, @@ -725,7 +725,7 @@ async def _handle_handshake( if len(message_bytes) > 0: plaintext = decrypt_message( encryption_key=result.session.recv_key, - nonce=bytes(header.nonce), + nonce=header.nonce, ciphertext=message_bytes, message_ad=message_ad, ) @@ -763,7 +763,7 @@ async def _handle_message( try: plaintext = decrypt_message( encryption_key=session.recv_key, - nonce=bytes(header.nonce), + nonce=header.nonce, ciphertext=message_bytes, message_ad=message_ad, ) @@ -828,11 +828,11 @@ async def _send_whoareyou( # # This IV is part of the challenge_data used for key derivation. # Both sides must use identical challenge_data to derive matching keys. - masking_iv = os.urandom(16) + masking_iv = Bytes16(os.urandom(16)) id_nonce, authdata, nonce, challenge_data = self._handshake_manager.create_whoareyou( remote_node_id=remote_node_id, - request_nonce=bytes(request_nonce), + request_nonce=request_nonce, remote_enr_seq=remote_enr_seq, masking_iv=masking_iv, ) @@ -844,7 +844,7 @@ async def _send_whoareyou( authdata=authdata, message=b"", encryption_key=None, - masking_iv=Bytes16(masking_iv), + masking_iv=masking_iv, ) self._transport.sendto(packet, addr) @@ -893,7 +893,7 @@ async def send_response( packet = encode_packet( dest_node_id=dest_node_id, flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=session.send_key, diff --git a/src/lean_spec/subspecs/networking/enr/enr.py b/src/lean_spec/subspecs/networking/enr/enr.py index 3afc188c..b0e6879c 100644 --- a/src/lean_spec/subspecs/networking/enr/enr.py +++ b/src/lean_spec/subspecs/networking/enr/enr.py @@ -50,16 +50,14 @@ encode_dss_signature, ) -from lean_spec.subspecs.networking.types import Multiaddr, NodeId, SeqNumber +from lean_spec.subspecs.networking.types import ForkDigest, Multiaddr, NodeId, SeqNumber, Version from lean_spec.types import ( - Bytes32, Bytes33, Bytes64, StrictBaseModel, Uint64, rlp, ) -from lean_spec.types.byte_arrays import Bytes4 from . import keys from .eth2 import AttestationSubnets, Eth2Data, SyncCommitteeSubnets @@ -159,8 +157,8 @@ def eth2_data(self) -> Eth2Data | None: eth2_bytes = self.get(keys.ETH2) if eth2_bytes and len(eth2_bytes) >= 16: return Eth2Data( - fork_digest=Bytes4(eth2_bytes[0:4]), - next_fork_version=Bytes4(eth2_bytes[4:8]), + fork_digest=ForkDigest(eth2_bytes[0:4]), + next_fork_version=Version(eth2_bytes[4:8]), next_fork_epoch=Uint64(int.from_bytes(eth2_bytes[8:16], "little")), ) return None @@ -211,7 +209,7 @@ def _build_content_items(self) -> list[bytes]: Returns [seq, k1, v1, k2, v2, ...] with keys sorted lexicographically. """ - sorted_keys = sorted(self.pairs.keys()) + sorted_keys = sorted(self.pairs) # Sequence number: minimal big-endian, empty bytes for zero. seq_bytes = self.seq.to_bytes(8, "big").lstrip(b"\x00") or b"" @@ -292,7 +290,7 @@ def compute_node_id(self) -> NodeId | None: # Hash the 64-byte x||y (excluding 0x04 prefix). k = keccak.new(digest_bits=256) k.update(uncompressed[1:]) - return Bytes32(k.digest()) + return NodeId(k.digest()) except (ValueError, TypeError): return None diff --git a/src/lean_spec/subspecs/networking/enr/eth2.py b/src/lean_spec/subspecs/networking/enr/eth2.py index 5df7ab12..da835d81 100644 --- a/src/lean_spec/subspecs/networking/enr/eth2.py +++ b/src/lean_spec/subspecs/networking/enr/eth2.py @@ -135,9 +135,9 @@ def is_subscribed(self, subnet_id: int) -> bool: raise ValueError(f"Sync subnet ID must be 0-3, got {subnet_id}") return bool(self.data[subnet_id]) - def subscribed_subnets(self) -> list[int]: + def subscribed_subnets(self) -> list[SubnetId]: """List of subscribed sync subnet IDs.""" - return [i for i in range(self.LENGTH) if self.data[i]] + return [SubnetId(i) for i in range(self.LENGTH) if self.data[i]] def subscription_count(self) -> int: """Number of subscribed sync subnets.""" diff --git a/src/lean_spec/subspecs/networking/enr/keys.py b/src/lean_spec/subspecs/networking/enr/keys.py index 6b83764f..5444a881 100644 --- a/src/lean_spec/subspecs/networking/enr/keys.py +++ b/src/lean_spec/subspecs/networking/enr/keys.py @@ -11,7 +11,7 @@ from typing import Final EnrKey = str -"""Type alias for ENR keys (can be any string/bytes per EIP-778)""" +"""ENR key identifier (any string/bytes per EIP-778).""" # EIP-778 Standard Keys ID: Final[EnrKey] = "id" diff --git a/src/lean_spec/subspecs/networking/gossipsub/behavior.py b/src/lean_spec/subspecs/networking/gossipsub/behavior.py index 27196d23..ef613220 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/behavior.py +++ b/src/lean_spec/subspecs/networking/gossipsub/behavior.py @@ -86,7 +86,7 @@ from lean_spec.subspecs.networking.transport import PeerId from lean_spec.subspecs.networking.transport.quic.stream_adapter import QuicStreamAdapter from lean_spec.subspecs.networking.varint import decode_varint, encode_varint -from lean_spec.types import Bytes20, Uint16 +from lean_spec.types import Uint16 logger = logging.getLogger(__name__) @@ -676,7 +676,7 @@ async def _handle_ihave(self, peer_id: PeerId, ihave: ControlIHave) -> None: # Message IDs must be exactly 20 bytes (SHA256 truncated to 160 bits). if len(msg_id) != 20: continue - msg_id_typed = Bytes20(msg_id) + msg_id_typed = MessageId(msg_id) if not self.seen_cache.has(msg_id_typed) and not self.message_cache.has(msg_id_typed): wanted.append(msg_id) @@ -695,7 +695,7 @@ async def _handle_iwant(self, peer_id: PeerId, iwant: ControlIWant) -> None: for msg_id in iwant.message_ids: if len(msg_id) != 20: continue - msg_id_typed = Bytes20(msg_id) + msg_id_typed = MessageId(msg_id) cached = self.message_cache.get(msg_id_typed) if cached: messages.append(Message(topic=cached.topic.decode("utf-8"), data=cached.raw_data)) @@ -718,7 +718,7 @@ def _handle_idontwant(self, peer_id: PeerId, idontwant: ControlIDontWant) -> Non for msg_id in idontwant.message_ids: if len(msg_id) != 20: continue - state.dont_want_ids.add(Bytes20(msg_id)) + state.dont_want_ids.add(MessageId(msg_id)) async def _heartbeat_loop(self) -> None: """Background heartbeat for mesh maintenance.""" diff --git a/src/lean_spec/subspecs/networking/gossipsub/message.py b/src/lean_spec/subspecs/networking/gossipsub/message.py index 952a3fa9..2037432d 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/message.py +++ b/src/lean_spec/subspecs/networking/gossipsub/message.py @@ -62,7 +62,6 @@ MESSAGE_DOMAIN_INVALID_SNAPPY, MESSAGE_DOMAIN_VALID_SNAPPY, ) -from lean_spec.types import Bytes20 from .types import MessageId @@ -172,7 +171,7 @@ def compute_id( preimage = bytes(domain) + len(topic).to_bytes(8, "little") + topic + data_for_hash - return Bytes20(hashlib.sha256(preimage).digest()[:20]) + return MessageId(hashlib.sha256(preimage).digest()[:20]) def __hash__(self) -> int: """Hash based on message ID. diff --git a/src/lean_spec/subspecs/networking/gossipsub/types.py b/src/lean_spec/subspecs/networking/gossipsub/types.py index 72361091..4c5aa0d5 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/types.py +++ b/src/lean_spec/subspecs/networking/gossipsub/types.py @@ -4,15 +4,16 @@ from lean_spec.types import Bytes20 -type MessageId = Bytes20 -"""20-byte message identifier. -Computed from message contents using SHA256:: +class MessageId(Bytes20): + """20-byte message identifier. - SHA256(domain + uint64_le(len(topic)) + topic + data)[:20] + Computed from message contents using SHA256:: -The domain byte distinguishes valid/invalid snappy compression. -""" + SHA256(domain + uint64_le(len(topic)) + topic + data)[:20] + + The domain byte distinguishes valid/invalid snappy compression. + """ type TopicId = str diff --git a/src/lean_spec/subspecs/networking/peer.py b/src/lean_spec/subspecs/networking/peer.py index ac309073..13e16d50 100644 --- a/src/lean_spec/subspecs/networking/peer.py +++ b/src/lean_spec/subspecs/networking/peer.py @@ -7,14 +7,14 @@ from typing import TYPE_CHECKING from .transport import PeerId -from .types import ConnectionState, Direction, Multiaddr +from .types import ConnectionState, Direction, ForkDigest, Multiaddr if TYPE_CHECKING: from .enr import ENR from .reqresp import Status -@dataclass +@dataclass(slots=True) class PeerInfo: """ Information about a known peer. @@ -59,7 +59,7 @@ def update_last_seen(self) -> None: self.last_seen = time() @property - def fork_digest(self) -> bytes | None: + def fork_digest(self) -> ForkDigest | None: """ Get the peer's fork_digest from cached ENR. @@ -71,4 +71,4 @@ def fork_digest(self) -> bytes | None: eth2_data = self.enr.eth2_data if eth2_data is None: return None - return bytes(eth2_data.fork_digest) + return eth2_data.fork_digest diff --git a/src/lean_spec/subspecs/networking/service/service.py b/src/lean_spec/subspecs/networking/service/service.py index 28fa9bbf..b06e2699 100644 --- a/src/lean_spec/subspecs/networking/service/service.py +++ b/src/lean_spec/subspecs/networking/service/service.py @@ -232,7 +232,7 @@ async def publish_attestation( compressed = frame_compress(ssz_bytes) await self.event_source.publish(str(topic), compressed) - logger.debug("Published attestation for slot %s", attestation.message.slot) + logger.debug("Published attestation for slot %s", attestation.data.slot) async def publish_aggregated_attestation( self, signed_attestation: SignedAggregatedAttestation diff --git a/src/lean_spec/subspecs/networking/transport/identity/keypair.py b/src/lean_spec/subspecs/networking/transport/identity/keypair.py index 7b23c3b1..4b09b838 100644 --- a/src/lean_spec/subspecs/networking/transport/identity/keypair.py +++ b/src/lean_spec/subspecs/networking/transport/identity/keypair.py @@ -16,6 +16,8 @@ from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec +from lean_spec.types import Bytes32, Bytes33 + from ..peer_id import KeyType, PeerId, PublicKeyProto __all__ = [ @@ -28,7 +30,7 @@ class IdentityKeypair: """ secp256k1 keypair for libp2p identity. - Used to derive PeerId and sign identity proofs during Noise handshake. + Used to derive PeerId and sign identity proofs during QUIC TLS handshake. Attributes: private_key: The secp256k1 private key. @@ -48,7 +50,7 @@ def generate(cls) -> IdentityKeypair: return cls(private_key=private_key) @classmethod - def from_bytes(cls, data: bytes) -> IdentityKeypair: + def from_bytes(cls, data: Bytes32) -> IdentityKeypair: """ Load keypair from raw private key bytes. @@ -57,20 +59,14 @@ def from_bytes(cls, data: bytes) -> IdentityKeypair: Returns: Identity keypair. - - Raises: - ValueError: If data is not a valid secp256k1 private key. """ - if len(data) != 32: - raise ValueError(f"Expected 32 bytes, got {len(data)}") - private_key = ec.derive_private_key( int.from_bytes(data, "big"), ec.SECP256K1(), ) return cls(private_key=private_key) - def private_key_bytes(self) -> bytes: + def private_key_bytes(self) -> Bytes32: """ Return the raw 32-byte private key. @@ -78,9 +74,9 @@ def private_key_bytes(self) -> bytes: 32-byte private key scalar. """ private_numbers = self.private_key.private_numbers() - return private_numbers.private_value.to_bytes(32, "big") + return Bytes32(private_numbers.private_value.to_bytes(32, "big")) - def public_key_bytes(self) -> bytes: + def public_key_bytes(self) -> Bytes33: """ Return the compressed secp256k1 public key (33 bytes). @@ -91,9 +87,11 @@ def public_key_bytes(self) -> bytes: 33-byte compressed public key. """ public_key = self.private_key.public_key() - return public_key.public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + return Bytes33( + public_key.public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) def sign(self, message: bytes) -> bytes: @@ -128,7 +126,7 @@ def to_peer_id(self) -> PeerId: def verify_signature( - public_key_bytes: bytes, + public_key_bytes: Bytes33, message: bytes, signature: bytes, ) -> bool: diff --git a/src/lean_spec/subspecs/networking/transport/identity/signature.py b/src/lean_spec/subspecs/networking/transport/identity/signature.py index 59fe2eb9..5c98c501 100644 --- a/src/lean_spec/subspecs/networking/transport/identity/signature.py +++ b/src/lean_spec/subspecs/networking/transport/identity/signature.py @@ -1,11 +1,11 @@ """ -Identity proof for libp2p-noise handshake. +Identity proof for libp2p peer authentication. -During the Noise handshake, peers must prove they own their claimed -libp2p identity key by signing the Noise static public key. +During the QUIC TLS handshake, peers must prove they own their claimed +libp2p identity key by signing the TLS public key. -The signature format follows the libp2p-noise specification: - message = "noise-libp2p-static-key:" || noise_public_key +The signature format follows the libp2p specification: + message = "noise-libp2p-static-key:" || public_key signature = ECDSA-SHA256(identity_private_key, message) References: @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING, Final +from lean_spec.types import Bytes32, Bytes33 + from .keypair import verify_signature if TYPE_CHECKING: @@ -34,43 +36,42 @@ def create_identity_proof( identity_key: IdentityKeypair, - noise_public_key: bytes, + public_key: Bytes32, ) -> bytes: """ - Create identity proof signature for Noise handshake. + Create identity proof signature for peer authentication. Proves that the owner of the identity key (secp256k1) also controls - the Noise static key (X25519). This binding prevents man-in-the-middle - attacks where an attacker substitutes their Noise key. + the TLS static key. This binding prevents man-in-the-middle attacks. Args: identity_key: The secp256k1 identity keypair. - noise_public_key: The 32-byte X25519 Noise static public key. + public_key: The 32-byte TLS public key. Returns: DER-encoded ECDSA signature. """ - message = NOISE_IDENTITY_PREFIX + noise_public_key + message = NOISE_IDENTITY_PREFIX + public_key return identity_key.sign(message) def verify_identity_proof( - identity_public_key: bytes, - noise_public_key: bytes, + identity_public_key: Bytes33, + public_key: Bytes32, signature: bytes, ) -> bool: """ Verify identity proof signature. - Called during Noise handshake to verify the remote peer's identity claim. + Called during QUIC TLS handshake to verify the remote peer's identity claim. Args: identity_public_key: 33-byte compressed secp256k1 public key. - noise_public_key: 32-byte X25519 Noise static public key. + public_key: 32-byte TLS public key. signature: DER-encoded ECDSA signature. Returns: True if the signature is valid, False otherwise. """ - message = NOISE_IDENTITY_PREFIX + noise_public_key + message = NOISE_IDENTITY_PREFIX + public_key return verify_signature(identity_public_key, message, signature) diff --git a/src/lean_spec/subspecs/networking/transport/peer_id.py b/src/lean_spec/subspecs/networking/transport/peer_id.py index 05e6e277..2a1f4848 100644 --- a/src/lean_spec/subspecs/networking/transport/peer_id.py +++ b/src/lean_spec/subspecs/networking/transport/peer_id.py @@ -38,6 +38,7 @@ from typing import Final from lean_spec.subspecs.networking import varint +from lean_spec.types import Bytes33 __all__ = [ # Main types @@ -408,7 +409,7 @@ def from_public_key(cls, public_key: PublicKeyProto) -> PeerId: return cls(multihash=mh.encode()) @classmethod - def from_secp256k1(cls, public_key_bytes: bytes) -> PeerId: + def from_secp256k1(cls, public_key_bytes: Bytes33) -> PeerId: """ Derive PeerId from a secp256k1 compressed public key. @@ -421,14 +422,6 @@ def from_secp256k1(cls, public_key_bytes: bytes) -> PeerId: Returns: Derived PeerId (starts with "16Uiu2..." for secp256k1). - - Raises: - ValueError: If public key is not 33 bytes. """ - if len(public_key_bytes) != 33: - raise ValueError( - f"secp256k1 compressed key must be 33 bytes, got {len(public_key_bytes)}" - ) - proto = PublicKeyProto(key_type=KeyType.SECP256K1, key_data=public_key_bytes) return cls.from_public_key(proto) diff --git a/src/lean_spec/subspecs/networking/types.py b/src/lean_spec/subspecs/networking/types.py index a7fb86b3..208b5367 100644 --- a/src/lean_spec/subspecs/networking/types.py +++ b/src/lean_spec/subspecs/networking/types.py @@ -5,23 +5,27 @@ from lean_spec.types import Uint64 from lean_spec.types.byte_arrays import Bytes1, Bytes4, Bytes32 -DomainType = Bytes1 -"""1-byte domain for message-id isolation in Gossipsub. -The domain is a single byte prepended to the message hash to compute the gossip message ID. +class DomainType(Bytes1): + """1-byte domain for message-id isolation in Gossipsub. -- Valid messages use 0x01, -- Invalid messages use 0x00. -""" + The domain is a single byte prepended to the message hash to compute the gossip message ID. -NodeId = Bytes32 -"""32-byte node identifier for Discovery v5, derived from ``keccak256(pubkey)``.""" + - Valid messages use 0x01, + - Invalid messages use 0x00. + """ + + +class NodeId(Bytes32): + """32-byte node identifier for Discovery v5, derived from ``keccak256(pubkey)``.""" + + +class ForkDigest(Bytes4): + """4-byte fork identifier ensuring network isolation between forks.""" -ForkDigest = Bytes4 -"""4-byte fork identifier ensuring network isolation between forks.""" -Version = Bytes4 -"""4-byte fork version number (e.g., 0x01000000 for Phase0).""" +class Version(Bytes4): + """4-byte fork version number (e.g., 0x01000000 for Phase0).""" class SeqNumber(Uint64): diff --git a/src/lean_spec/subspecs/node/node.py b/src/lean_spec/subspecs/node/node.py index 3372b02c..7466fb6e 100644 --- a/src/lean_spec/subspecs/node/node.py +++ b/src/lean_spec/subspecs/node/node.py @@ -205,7 +205,7 @@ def from_genesis(cls, config: NodeConfig) -> Node: # Initialize forkchoice store. # # Genesis block is both justified and finalized. - store = Store.get_forkchoice_store(state, block, validator_id) + store = state.to_forkchoice_store(block, validator_id) # Persist genesis to database if available. if database is not None: diff --git a/src/lean_spec/subspecs/ssz/hash.py b/src/lean_spec/subspecs/ssz/hash.py index b4baaac8..d7b6b292 100644 --- a/src/lean_spec/subspecs/ssz/hash.py +++ b/src/lean_spec/subspecs/ssz/hash.py @@ -149,9 +149,7 @@ def _htr_list(value: SSZList) -> Bytes32: @hash_tree_root.register def _htr_container(value: Container) -> Bytes32: # Preserve declared field order from the Pydantic model - return merkleize( - [hash_tree_root(getattr(value, fname)) for fname in type(value).model_fields.keys()] - ) + return merkleize([hash_tree_root(getattr(value, fname)) for fname in type(value).model_fields]) @hash_tree_root.register diff --git a/src/lean_spec/subspecs/sync/block_cache.py b/src/lean_spec/subspecs/sync/block_cache.py index 8ce51871..5728bcad 100644 --- a/src/lean_spec/subspecs/sync/block_cache.py +++ b/src/lean_spec/subspecs/sync/block_cache.py @@ -345,7 +345,7 @@ def get_children(self, parent_root: Bytes32) -> list[PendingBlock]: # parent P, we must process A before B. return sorted(children, key=lambda p: p.slot) - def get_processable(self, store: "Store") -> list[PendingBlock]: + def get_processable(self, store: Store) -> list[PendingBlock]: """ Get blocks whose parents exist in the Store. diff --git a/src/lean_spec/subspecs/sync/checkpoint_sync.py b/src/lean_spec/subspecs/sync/checkpoint_sync.py index 8e418804..18fa9b26 100644 --- a/src/lean_spec/subspecs/sync/checkpoint_sync.py +++ b/src/lean_spec/subspecs/sync/checkpoint_sync.py @@ -44,7 +44,7 @@ class CheckpointSyncError(Exception): """ -async def fetch_finalized_state(url: str, state_class: type["State"]) -> "State": +async def fetch_finalized_state(url: str, state_class: type[State]) -> State: """ Fetch finalized state from a node via checkpoint sync. @@ -102,7 +102,7 @@ async def fetch_finalized_state(url: str, state_class: type["State"]) -> "State" raise CheckpointSyncError(f"Failed to fetch state: {e}") from e -def verify_checkpoint_state(state: "State") -> bool: +def verify_checkpoint_state(state: State) -> bool: """ Verify that a checkpoint state is structurally valid. diff --git a/src/lean_spec/subspecs/sync/head_sync.py b/src/lean_spec/subspecs/sync/head_sync.py index a04cf282..198564ee 100644 --- a/src/lean_spec/subspecs/sync/head_sync.py +++ b/src/lean_spec/subspecs/sync/head_sync.py @@ -45,8 +45,8 @@ from __future__ import annotations import logging +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Callable from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.forkchoice import Store diff --git a/src/lean_spec/subspecs/sync/states.py b/src/lean_spec/subspecs/sync/states.py index 20fb1317..2e672a92 100644 --- a/src/lean_spec/subspecs/sync/states.py +++ b/src/lean_spec/subspecs/sync/states.py @@ -98,7 +98,7 @@ class SyncState(Enum): - Falls back to SYNCING if gaps appear """ - def can_transition_to(self, target: "SyncState") -> bool: + def can_transition_to(self, target: SyncState) -> bool: """ Check if transition to target state is valid. diff --git a/src/lean_spec/subspecs/validator/service.py b/src/lean_spec/subspecs/validator/service.py index dfa1a685..6333e36b 100644 --- a/src/lean_spec/subspecs/validator/service.py +++ b/src/lean_spec/subspecs/validator/service.py @@ -492,7 +492,7 @@ def _sign_attestation( return SignedAttestation( validator_id=validator_index, - message=attestation_data, + data=attestation_data, signature=signature, ) diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py index c9ad8060..e7a5a795 100644 --- a/src/lean_spec/subspecs/xmss/containers.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -79,9 +79,9 @@ def _serialize_as_bytes(self) -> str: def verify( self, public_key: PublicKey, - slot: "Slot", - message: "Bytes32", - scheme: "GeneralizedXmssScheme", + slot: Slot, + message: Bytes32, + scheme: GeneralizedXmssScheme, ) -> bool: """ Verify the signature using XMSS verification algorithm. diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 2c1bbe25..831948b8 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -62,7 +62,7 @@ class GeneralizedXmssScheme(StrictBaseModel): """Random data generator for key generation.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "GeneralizedXmssScheme": + def _validate_strict_types(self) -> GeneralizedXmssScheme: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types( self, diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index e6c823f2..13deb074 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -65,7 +65,7 @@ class MessageHasher(StrictBaseModel): """Poseidon hash engine.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "MessageHasher": + def _validate_strict_types(self) -> MessageHasher: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 92f72cc6..312798b1 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -47,7 +47,7 @@ class PoseidonXmss(StrictBaseModel): """Poseidon2 parameters for 24-width permutation.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "PoseidonXmss": + def _validate_strict_types(self) -> PoseidonXmss: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, params16=Poseidon2Params, params24=Poseidon2Params) return self diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 8dd455a9..8b9a8c64 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -110,7 +110,7 @@ class Prf(StrictBaseModel): """Configuration parameters for the PRF.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "Prf": + def _validate_strict_types(self) -> Prf: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig) return self diff --git a/src/lean_spec/subspecs/xmss/subtree.py b/src/lean_spec/subspecs/xmss/subtree.py index c1f1f6a2..3b3a7aa5 100644 --- a/src/lean_spec/subspecs/xmss/subtree.py +++ b/src/lean_spec/subspecs/xmss/subtree.py @@ -325,14 +325,14 @@ def new_bottom_tree( @classmethod def from_prf_key( cls, - prf: "Prf", - hasher: "TweakHasher", - rand: "Rand", - config: "XmssConfig", + prf: Prf, + hasher: TweakHasher, + rand: Rand, + config: XmssConfig, prf_key: PRFKey, bottom_tree_index: Uint64, parameter: Parameter, - ) -> "HashSubTree": + ) -> HashSubTree: """ Generates a single bottom tree on-demand from the PRF key. @@ -544,7 +544,7 @@ def combined_path( def verify_path( - hasher: "TweakHasher", + hasher: TweakHasher, parameter: Parameter, root: HashDigestVector, position: Uint64, diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index ec8b0538..3657f051 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -87,7 +87,7 @@ class TweakHasher(StrictBaseModel): """Poseidon permutation instance for hashing.""" @model_validator(mode="after") - def _validate_strict_types(self) -> "TweakHasher": + def _validate_strict_types(self) -> TweakHasher: """Reject subclasses to prevent type confusion attacks.""" enforce_strict_types(self, config=XmssConfig, poseidon=PoseidonXmss) return self diff --git a/tests/consensus/devnet/ssz/test_consensus_containers.py b/tests/consensus/devnet/ssz/test_consensus_containers.py index 9a1ba7aa..5f390748 100644 --- a/tests/consensus/devnet/ssz/test_consensus_containers.py +++ b/tests/consensus/devnet/ssz/test_consensus_containers.py @@ -127,7 +127,7 @@ def test_signed_attestation_minimal(ssz: SSZTestFiller) -> None: type_name="SignedAttestation", value=SignedAttestation( validator_id=ValidatorIndex(0), - message=_zero_attestation_data(), + data=_zero_attestation_data(), signature=_empty_signature(), ), ) diff --git a/tests/lean_spec/conftest.py b/tests/lean_spec/conftest.py index 20249991..d544f6db 100644 --- a/tests/lean_spec/conftest.py +++ b/tests/lean_spec/conftest.py @@ -71,8 +71,7 @@ def genesis_block(genesis_state: State) -> Block: @pytest.fixture def base_store(genesis_state: State, genesis_block: Block) -> Store: """Fork choice store initialized with genesis.""" - return Store.get_forkchoice_store( - genesis_state, + return genesis_state.to_forkchoice_store( genesis_block, validator_id=ValidatorIndex(0), ) @@ -108,4 +107,4 @@ def keyed_store(keyed_genesis: GenesisData) -> Store: @pytest.fixture def observer_store(keyed_genesis_state: State, keyed_genesis_block: Block) -> Store: """Fork choice store with validator_id=None (non-validator observer).""" - return Store.get_forkchoice_store(keyed_genesis_state, keyed_genesis_block, validator_id=None) + return keyed_genesis_state.to_forkchoice_store(keyed_genesis_block, validator_id=None) diff --git a/tests/lean_spec/helpers/builders.py b/tests/lean_spec/helpers/builders.py index c07becac..edc2deb2 100644 --- a/tests/lean_spec/helpers/builders.py +++ b/tests/lean_spec/helpers/builders.py @@ -30,7 +30,7 @@ from lean_spec.subspecs.containers.block.types import AggregatedAttestations, AttestationSignatures from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state import Validators -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.koalabear import Fp from lean_spec.subspecs.networking import PeerId @@ -240,7 +240,9 @@ def make_aggregated_attestation( ) return AggregatedAttestation( - aggregation_bits=AggregationBits.from_validator_indices(participant_ids), + aggregation_bits=AggregationBits.from_validator_indices( + ValidatorIndices(data=participant_ids) + ), data=data, ) @@ -264,7 +266,7 @@ def make_signed_attestation( ) return SignedAttestation( validator_id=validator, - message=attestation_data, + data=attestation_data, signature=make_mock_signature(), ) @@ -326,7 +328,7 @@ def make_genesis_data( validators = make_validators(num_validators) genesis_state = make_genesis_state(validators=validators, genesis_time=genesis_time) genesis_block = make_genesis_block(genesis_state) - store = Store.get_forkchoice_store(genesis_state, genesis_block, validator_id=validator_id) + store = genesis_state.to_forkchoice_store(genesis_block, validator_id=validator_id) return GenesisData(store, genesis_state, genesis_block) @@ -423,7 +425,7 @@ def make_aggregated_proof( """Create a valid aggregated signature proof for the given participants.""" data_root = attestation_data.data_root_bytes() return AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices(ValidatorIndices(data=participants)), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants diff --git a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py index 83dfb8bd..d50215dc 100644 --- a/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py +++ b/tests/lean_spec/subspecs/containers/test_attestation_aggregation.py @@ -61,7 +61,9 @@ def test_aggregated_attestation_structure(self) -> None: source=Checkpoint(root=Bytes32.zero(), slot=Slot(2)), ) - bits = AggregationBits.from_validator_indices([ValidatorIndex(2), ValidatorIndex(7)]) + bits = AggregationBits.from_validator_indices( + ValidatorIndices(data=[ValidatorIndex(2), ValidatorIndex(7)]) + ) agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) # Verify we can extract validator indices @@ -78,12 +80,12 @@ def test_aggregated_attestation_with_many_validators(self) -> None: source=Checkpoint(root=Bytes32.zero(), slot=Slot(7)), ) - validator_ids = [ValidatorIndex(i) for i in [0, 5, 10, 15, 20, 25]] + validator_ids = ValidatorIndices(data=[ValidatorIndex(i) for i in [0, 5, 10, 15, 20, 25]]) bits = AggregationBits.from_validator_indices(validator_ids) agg = AggregatedAttestation(aggregation_bits=bits, data=att_data) recovered = agg.aggregation_bits.to_validator_indices() - assert recovered == ValidatorIndices(data=validator_ids) + assert recovered == validator_ids class TestAggregateByData: diff --git a/tests/lean_spec/subspecs/containers/test_state_process_attestations.py b/tests/lean_spec/subspecs/containers/test_state_process_attestations.py index 33f9bca4..b03aa08e 100644 --- a/tests/lean_spec/subspecs/containers/test_state_process_attestations.py +++ b/tests/lean_spec/subspecs/containers/test_state_process_attestations.py @@ -50,7 +50,7 @@ HistoricalBlockHashes, JustifiedSlots, ) -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.types import Boolean from tests.lean_spec.helpers import make_bytes32, make_genesis_state @@ -127,7 +127,7 @@ def test_attestation_with_target_beyond_history_is_silently_rejected(self) -> No attestation = AggregatedAttestation( # Two validators participate in this attestation. aggregation_bits=AggregationBits.from_validator_indices( - [ValidatorIndex(0), ValidatorIndex(1)] + ValidatorIndices(data=[ValidatorIndex(0), ValidatorIndex(1)]) ), data=att_data, ) @@ -209,7 +209,7 @@ def test_attestation_with_source_beyond_history_is_silently_rejected(self) -> No attestation = AggregatedAttestation( aggregation_bits=AggregationBits.from_validator_indices( - [ValidatorIndex(0), ValidatorIndex(1)] + ValidatorIndices(data=[ValidatorIndex(0), ValidatorIndex(1)]) ), data=att_data, ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py index 4a06e3af..4283d038 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py +++ b/tests/lean_spec/subspecs/forkchoice/test_attestation_target.py @@ -157,7 +157,7 @@ def test_safe_target_requires_supermajority( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) # Process as gossip (requires aggregator flag) @@ -202,7 +202,7 @@ def test_safe_target_advances_with_supermajority( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -242,7 +242,7 @@ def test_update_safe_target_uses_new_attestations( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -296,7 +296,7 @@ def test_justification_with_supermajority_attestations( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -378,7 +378,7 @@ def test_justification_tracking_with_multiple_targets( sig = key_manager.sign_attestation_data(vid, attestation_data_head) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data_head, + data=attestation_data_head, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -427,7 +427,7 @@ def test_finalization_after_consecutive_justification( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) store = store.on_gossip_attestation(signed_attestation, is_aggregator=True) @@ -521,7 +521,7 @@ def test_full_attestation_cycle( sig = key_manager.sign_attestation_data(vid, attestation_data) signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=sig, ) # Process as gossip diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py index e2cd1e28..ff3c1ebb 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py +++ b/tests/lean_spec/subspecs/forkchoice/test_store_attestations.py @@ -16,7 +16,7 @@ ) from lean_spec.subspecs.containers.checkpoint import Checkpoint from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof, SignatureKey from lean_spec.types import Bytes32, Uint64 from tests.lean_spec.helpers import ( @@ -173,7 +173,7 @@ def test_same_subnet_stores_signature(self, key_manager: XmssKeyManager) -> None signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -216,7 +216,7 @@ def test_cross_subnet_ignores_signature(self, key_manager: XmssKeyManager) -> No signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -251,7 +251,7 @@ def test_non_aggregator_never_stores_signature(self, key_manager: XmssKeyManager signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -286,7 +286,7 @@ def test_attestation_data_always_stored(self, key_manager: XmssKeyManager) -> No signed_attestation = SignedAttestation( validator_id=attester_validator, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(attester_validator, attestation_data), ) @@ -336,7 +336,9 @@ def test_valid_proof_stored_correctly(self, key_manager: XmssKeyManager) -> None # Create valid aggregated proof proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants @@ -377,7 +379,9 @@ def test_attestation_data_stored_by_root(self, key_manager: XmssKeyManager) -> N data_root = attestation_data.data_root_bytes() proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants @@ -413,7 +417,9 @@ def test_invalid_proof_rejected(self, key_manager: XmssKeyManager) -> None: # Create proof with WRONG signers (validator 3 signs instead of 2) proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(claimed_participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=claimed_participants) + ), public_keys=[key_manager.get_public_key(vid) for vid in actual_signers], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in actual_signers @@ -446,7 +452,9 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: # First proof: validators 1 and 2 participants_1 = [ValidatorIndex(1), ValidatorIndex(2)] proof_1 = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants_1), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants_1) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants_1], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_1 @@ -458,7 +466,9 @@ def test_multiple_proofs_accumulate(self, key_manager: XmssKeyManager) -> None: # Second proof: validators 1 and 3 (validator 1 overlaps) participants_2 = [ValidatorIndex(1), ValidatorIndex(3)] proof_2 = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants_2), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants_2) + ), public_keys=[key_manager.get_public_key(vid) for vid in participants_2], signatures=[ key_manager.sign_attestation_data(vid, attestation_data) for vid in participants_2 @@ -797,7 +807,7 @@ def test_gossip_to_aggregation_to_storage(self, key_manager: XmssKeyManager) -> for vid in attesting_validators: signed_attestation = SignedAttestation( validator_id=vid, - message=attestation_data, + data=attestation_data, signature=key_manager.sign_attestation_data(vid, attestation_data), ) store = store.on_gossip_attestation( diff --git a/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py b/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py index dd36419d..4a5cacfb 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py +++ b/tests/lean_spec/subspecs/forkchoice/test_store_pruning.py @@ -2,7 +2,7 @@ from lean_spec.subspecs.containers.attestation import AggregationBits from lean_spec.subspecs.containers.slot import Slot -from lean_spec.subspecs.containers.validator import ValidatorIndex +from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices from lean_spec.subspecs.forkchoice import Store from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof, SignatureKey from lean_spec.types import Bytes32 @@ -149,7 +149,9 @@ def test_prunes_related_structures_together(pruning_store: Store) -> None: # Create mock aggregated proof (empty proof data for testing) mock_proof = AggregatedSignatureProof( - participants=AggregationBits.from_validator_indices([ValidatorIndex(1)]), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=[ValidatorIndex(1)]) + ), proof_data=ByteListMiB(data=b""), ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_time_management.py b/tests/lean_spec/subspecs/forkchoice/test_time_management.py index 2adbd852..4811c0f1 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_time_management.py +++ b/tests/lean_spec/subspecs/forkchoice/test_time_management.py @@ -21,14 +21,14 @@ class TestGetForkchoiceStore: - """Test Store.get_forkchoice_store() time initialization.""" + """Test State.to_forkchoice_store() time initialization.""" @settings(max_examples=100) @given(anchor_slot=st.integers(min_value=0, max_value=10000)) def test_store_time_from_anchor_slot(self, anchor_slot: int) -> None: - """get_forkchoice_store sets time = anchor_slot * INTERVALS_PER_SLOT.""" + """to_forkchoice_store sets time = anchor_slot * INTERVALS_PER_SLOT.""" # Must create its own state and block instead of using sample_store() - # because sample_store() bypasses get_forkchoice_store() with hardcoded time. + # because sample_store() bypasses to_forkchoice_store() with hardcoded time. state = State.generate_genesis( genesis_time=Uint64(1000), validators=Validators(data=[]), @@ -43,9 +43,8 @@ def test_store_time_from_anchor_slot(self, anchor_slot: int) -> None: body=make_empty_block_body(), ) - store = Store.get_forkchoice_store( - anchor_state=state, - anchor_block=anchor_block, + store = state.to_forkchoice_store( + anchor_block, validator_id=TEST_VALIDATOR_ID, ) diff --git a/tests/lean_spec/subspecs/forkchoice/test_validator.py b/tests/lean_spec/subspecs/forkchoice/test_validator.py index 57824518..3396eb9a 100644 --- a/tests/lean_spec/subspecs/forkchoice/test_validator.py +++ b/tests/lean_spec/subspecs/forkchoice/test_validator.py @@ -67,7 +67,7 @@ def test_produce_block_with_attestations( ) signed_5 = SignedAttestation( validator_id=ValidatorIndex(5), - message=data_5, + data=data_5, signature=key_manager.sign_attestation_data(ValidatorIndex(5), data_5), ) data_6 = AttestationData( @@ -78,15 +78,15 @@ def test_produce_block_with_attestations( ) signed_6 = SignedAttestation( validator_id=ValidatorIndex(6), - message=data_6, + data=data_6, signature=key_manager.sign_attestation_data(ValidatorIndex(6), data_6), ) - data_root_5 = signed_5.message.data_root_bytes() - data_root_6 = signed_6.message.data_root_bytes() + data_root_5 = signed_5.data.data_root_bytes() + data_root_6 = signed_6.data.data_root_bytes() - proof_5 = make_aggregated_proof(key_manager, [ValidatorIndex(5)], signed_5.message) - proof_6 = make_aggregated_proof(key_manager, [ValidatorIndex(6)], signed_6.message) + proof_5 = make_aggregated_proof(key_manager, [ValidatorIndex(5)], signed_5.data) + proof_6 = make_aggregated_proof(key_manager, [ValidatorIndex(6)], signed_6.data) sig_key_5 = SignatureKey(ValidatorIndex(5), data_root_5) sig_key_6 = SignatureKey(ValidatorIndex(6), data_root_6) @@ -98,8 +98,8 @@ def test_produce_block_with_attestations( sig_key_6: [proof_6], }, "attestation_data_by_root": { - data_root_5: signed_5.message, - data_root_6: signed_6.message, + data_root_5: signed_5.data, + data_root_6: signed_6.data, }, "gossip_signatures": { sig_key_5: signed_5.signature, @@ -216,18 +216,18 @@ def test_produce_block_state_consistency( ) signed_7 = SignedAttestation( validator_id=ValidatorIndex(7), - message=data_7, + data=data_7, signature=key_manager.sign_attestation_data(ValidatorIndex(7), data_7), ) - data_root_7 = signed_7.message.data_root_bytes() - proof_7 = make_aggregated_proof(key_manager, [ValidatorIndex(7)], signed_7.message) + data_root_7 = signed_7.data.data_root_bytes() + proof_7 = make_aggregated_proof(key_manager, [ValidatorIndex(7)], signed_7.data) sig_key_7 = SignatureKey(ValidatorIndex(7), data_root_7) sample_store = sample_store.model_copy( update={ "latest_known_aggregated_payloads": {sig_key_7: [proof_7]}, - "attestation_data_by_root": {data_root_7: signed_7.message}, + "attestation_data_by_root": {data_root_7: signed_7.data}, "gossip_signatures": {sig_key_7: signed_7.signature}, } ) diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index 84aa2325..03268196 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -4,27 +4,34 @@ import pytest +from lean_spec.subspecs.networking.discovery.messages import IdNonce from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes64 +from lean_spec.types import Bytes32, Bytes33, Bytes64 # From devp2p test vectors -NODE_A_PRIVKEY = bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") +NODE_A_PRIVKEY = Bytes32( + bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") +) NODE_A_ID = NodeId( bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") ) -NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") +NODE_B_PRIVKEY = Bytes32( + bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") +) NODE_B_ID = NodeId( bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") ) -NODE_B_PUBKEY = bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +NODE_B_PUBKEY = Bytes33( + bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +) # Spec id-nonce used in WHOAREYOU test vectors. -SPEC_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") +SPEC_ID_NONCE = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) @pytest.fixture -def local_private_key() -> bytes: +def local_private_key() -> Bytes32: """Node B's private key from devp2p test vectors.""" return NODE_B_PRIVKEY diff --git a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py index 413c9a47..e7f5cb6b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py @@ -56,16 +56,6 @@ def test_different_ivs_produce_different_ciphertext(self): assert ct1 != ct2 - def test_invalid_key_length_raises(self): - """Test that invalid key length raises ValueError.""" - with pytest.raises(ValueError, match="Key must be 16 bytes"): - aes_ctr_encrypt(bytes(15), bytes(16), b"data") # type: ignore[arg-type] - - def test_invalid_iv_length_raises(self): - """Test that invalid IV length raises ValueError.""" - with pytest.raises(ValueError, match="IV must be 16 bytes"): - aes_ctr_encrypt(bytes(16), bytes(15), b"data") # type: ignore[arg-type] - class TestAesGcm: """Tests for AES-GCM encryption/decryption.""" @@ -105,16 +95,6 @@ def test_wrong_aad_fails_decryption(self): with pytest.raises(InvalidTag): aes_gcm_decrypt(key, nonce, ciphertext, b"wrong aad") - def test_invalid_key_length_raises(self): - """Test that invalid key length raises ValueError.""" - with pytest.raises(ValueError, match="Key must be 16 bytes"): - aes_gcm_encrypt(bytes(15), bytes(12), b"data", b"") # type: ignore[arg-type] - - def test_invalid_nonce_length_raises(self): - """Test that invalid nonce length raises ValueError.""" - with pytest.raises(ValueError, match="Nonce must be 12 bytes"): - aes_gcm_encrypt(bytes(16), bytes(11), b"data", b"") # type: ignore[arg-type] - class TestEcdh: """Tests for secp256k1 ECDH key agreement.""" @@ -292,18 +272,6 @@ def test_zero_private_key_rejected(self): Bytes32.zero(), ) - def test_wrong_length_dest_node_id(self): - """Signing rejects non-32-byte destination node ID.""" - priv, _ = generate_secp256k1_keypair() - _, eph_pub = generate_secp256k1_keypair() - with pytest.raises((ValueError, TypeError)): - sign_id_nonce( - priv, - make_challenge_data(), - eph_pub, - bytes(16), # type: ignore[arg-type] - ) - class TestVerifyIdNonceNegativeCases: """Negative tests for ID nonce signature verification.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index 63c19809..e63d1fc0 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -16,7 +16,7 @@ PendingHandshake, ) from lean_spec.subspecs.networking.discovery.keys import compute_node_id -from lean_spec.subspecs.networking.discovery.messages import IdNonce +from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce from lean_spec.subspecs.networking.discovery.packet import ( HandshakeAuthdata, WhoAreYouAuthdata, @@ -27,7 +27,7 @@ from lean_spec.subspecs.networking.discovery.session import SessionCache from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes32, Bytes33, Bytes64 +from lean_spec.types import Bytes16, Bytes32, Bytes33, Bytes64 from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY @@ -143,9 +143,9 @@ def test_cancel_nonexistent_returns_false(self, manager): def test_create_whoareyou(self, manager): """Test creating a WHOAREYOU challenge.""" remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(0) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id, request_nonce, remote_enr_seq, masking_iv @@ -189,28 +189,6 @@ def test_cleanup_expired(self, manager): assert manager.get_pending(remote1) is None assert manager.get_pending(remote2) is None - def test_invalid_local_node_id_raises(self): - """Test that invalid local node ID raises ValueError.""" - with pytest.raises(ValueError, match="Local node ID must be 32 bytes"): - HandshakeManager( - local_node_id=bytes(31), # type: ignore[arg-type] - local_private_key=bytes(32), - local_enr_rlp=b"enr", - local_enr_seq=SeqNumber(1), - session_cache=SessionCache(), - ) - - def test_invalid_local_private_key_raises(self): - """Test that invalid local private key raises ValueError.""" - with pytest.raises(ValueError, match="Local private key must be 32 bytes"): - HandshakeManager( - local_node_id=NodeId(bytes(32)), - local_private_key=bytes(31), - local_enr_rlp=b"enr", - local_enr_seq=SeqNumber(1), - session_cache=SessionCache(), - ) - class TestHandshakeState: """Tests for HandshakeState enum.""" @@ -270,9 +248,9 @@ def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): When we receive an undecryptable MESSAGE, we respond with WHOAREYOU. """ remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(0) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id, request_nonce, remote_enr_seq, masking_iv @@ -288,9 +266,9 @@ def test_create_whoareyou_transitions_to_sent_whoareyou(self, manager): def test_sent_whoareyou_state_has_challenge_data(self, manager): """In SENT_WHOAREYOU state, all challenge data is stored.""" remote_node_id = bytes(32) - request_nonce = bytes(12) + request_nonce = Nonce(bytes(12)) remote_enr_seq = SeqNumber(5) - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) manager.create_whoareyou(remote_node_id, request_nonce, remote_enr_seq, masking_iv) @@ -330,8 +308,8 @@ def test_handle_handshake_requires_pending_state(self, manager, remote_keypair): src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -350,8 +328,8 @@ def test_handle_handshake_requires_sent_whoareyou_state(self, manager, remote_ke src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -366,9 +344,9 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) # Set up WHOAREYOU state. manager.create_whoareyou( NodeId(remote_node_id), - bytes(12), + Nonce(bytes(12)), SeqNumber(0), - bytes(16), + Bytes16(bytes(16)), ) # Create authdata with different src_id. @@ -377,8 +355,8 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) src_id=wrong_src_id, sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, ) @@ -397,9 +375,9 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa # Set up WHOAREYOU with enr_seq=0 (unknown). manager.create_whoareyou( NodeId(remote_node_id), - bytes(12), + Nonce(bytes(12)), SeqNumber(0), # enr_seq = 0 means we don't know remote's ENR - bytes(16), + Bytes16(bytes(16)), ) # Create authdata without ENR record. @@ -407,8 +385,8 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, - id_signature=bytes(64), - eph_pubkey=bytes(33), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes(33)), record=None, # No ENR included. ) @@ -426,9 +404,9 @@ def test_successful_handshake_with_signature_verification( remote_priv, remote_pub, remote_node_id = remote_keypair # Node A (manager) creates WHOAREYOU for remote. - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv + NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv ) # Remote creates handshake response. @@ -475,8 +453,8 @@ def test_handle_handshake_rejects_invalid_signature( remote_priv, remote_pub, remote_node_id = remote_keypair # Set up WHOAREYOU state. - masking_iv = bytes(16) - manager.create_whoareyou(NodeId(remote_node_id), bytes(12), SeqNumber(0), masking_iv) + masking_iv = Bytes16(bytes(16)) + manager.create_whoareyou(NodeId(remote_node_id), Nonce(bytes(12)), SeqNumber(0), masking_iv) # Generate ephemeral key. _eph_priv, eph_pub = generate_secp256k1_keypair() @@ -490,7 +468,7 @@ def test_handle_handshake_rejects_invalid_signature( authdata_bytes = encode_handshake_authdata( src_id=NodeId(remote_node_id), - id_signature=bytes(64), # Wrong signature. + id_signature=Bytes64(bytes(64)), # Wrong signature. eph_pubkey=eph_pub, record=remote_enr.to_rlp(), ) @@ -515,7 +493,7 @@ def test_multiple_handshakes_independent(self, manager): manager.start_handshake(remote2) # Create WHOAREYOU for third remote. - manager.create_whoareyou(remote3, bytes(12), SeqNumber(0), bytes(16)) + manager.create_whoareyou(remote3, Nonce(bytes(12)), SeqNumber(0), Bytes16(bytes(16))) # All should have independent state. assert manager.get_pending(remote1).state == HandshakeState.SENT_ORDINARY @@ -581,8 +559,10 @@ def test_id_nonce_uniqueness_across_challenges(self, manager): remote1 = bytes.fromhex("01" + "00" * 31) remote2 = bytes.fromhex("02" + "00" * 31) - id_nonce1, _, _, _ = manager.create_whoareyou(remote1, bytes(12), SeqNumber(0), bytes(16)) - id_nonce2, _, _, _ = manager.create_whoareyou(remote2, bytes(12), SeqNumber(0), bytes(16)) + nonce = Nonce(bytes(12)) + iv = Bytes16(bytes(16)) + id_nonce1, _, _, _ = manager.create_whoareyou(remote1, nonce, SeqNumber(0), iv) + id_nonce2, _, _, _ = manager.create_whoareyou(remote2, nonce, SeqNumber(0), iv) # Each challenge should have unique id_nonce. assert id_nonce1 != id_nonce2 @@ -619,7 +599,7 @@ def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypa authdata, _, _ = manager.create_handshake_response( remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, - remote_pubkey=bytes(remote_pub), + remote_pubkey=remote_pub, challenge_data=challenge_data, ) @@ -655,7 +635,7 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key authdata, _, _ = manager.create_handshake_response( remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, - remote_pubkey=bytes(remote_pub), + remote_pubkey=remote_pub, challenge_data=challenge_data, ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index bff0d39c..3bc5816b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -21,6 +21,7 @@ from lean_spec.subspecs.networking.discovery.handshake import HandshakeManager from lean_spec.subspecs.networking.discovery.keys import compute_node_id, derive_keys_from_pubkey from lean_spec.subspecs.networking.discovery.messages import ( + Nonce, PacketFlag, Ping, RequestId, @@ -41,7 +42,7 @@ from lean_spec.subspecs.networking.discovery.session import Session, SessionCache from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64 +from lean_spec.types import Bytes12, Bytes16, Bytes64 @pytest.fixture @@ -49,7 +50,7 @@ def node_a_keys(): """Node A's keypair.""" priv, pub = generate_secp256k1_keypair() node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": bytes(node_id)} + return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} @pytest.fixture @@ -57,7 +58,7 @@ def node_b_keys(): """Node B's keypair.""" priv, pub = generate_secp256k1_keypair() node_id = compute_node_id(pub) - return {"private_key": priv, "public_key": pub, "node_id": bytes(node_id)} + return {"private_key": priv, "public_key": pub, "node_id": NodeId(node_id)} class TestEncryptedPacketRoundtrip: @@ -67,7 +68,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): """MESSAGE packet encrypts and decrypts correctly.""" # Build mock challenge_data for key derivation. # Format: masking-iv (16) + static-header (23) + authdata (24) = 63 bytes. - masking_iv = bytes(16) + masking_iv = Bytes16(bytes(16)) static_header = b"discv5" + b"\x00\x01\x01" + bytes(12) + b"\x00\x18" authdata = bytes(24) challenge_data = masking_iv + static_header + authdata @@ -75,10 +76,10 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Create session keys (derived from ECDH). # Node A is initiator. send_key, recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(node_a_keys["private_key"]), + local_private_key=node_a_keys["private_key"], remote_public_key=node_b_keys["public_key"], - local_node_id=Bytes32(node_a_keys["node_id"]), - remote_node_id=Bytes32(node_b_keys["node_id"]), + local_node_id=node_a_keys["node_id"], + remote_node_id=node_b_keys["node_id"], challenge_data=challenge_data, is_initiator=True, ) @@ -98,7 +99,7 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): packet = encode_packet( dest_node_id=node_b_keys["node_id"], flag=PacketFlag.MESSAGE, - nonce=bytes(nonce), + nonce=nonce, authdata=authdata, message=message_bytes, encryption_key=send_key, @@ -115,10 +116,10 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Node B derives keys as recipient (using same challenge_data). b_send_key, b_recv_key = derive_keys_from_pubkey( - local_private_key=Bytes32(node_b_keys["private_key"]), + local_private_key=node_b_keys["private_key"], remote_public_key=node_a_keys["public_key"], - local_node_id=Bytes32(node_b_keys["node_id"]), - remote_node_id=Bytes32(node_a_keys["node_id"]), + local_node_id=node_b_keys["node_id"], + remote_node_id=node_a_keys["node_id"], challenge_data=challenge_data, is_initiator=False, ) @@ -145,8 +146,8 @@ def test_session_cache_operations(self, node_a_keys, node_b_keys): now = time.time() session = Session( node_id=node_b_keys["node_id"], - send_key=bytes(16), - recv_key=bytes(16), + send_key=Bytes16(bytes(16)), + recv_key=Bytes16(bytes(16)), created_at=now, last_seen=now, is_initiator=True, @@ -231,8 +232,8 @@ def test_whoareyou_generation(self, node_a_keys, node_b_keys): ) # Create WHOAREYOU. - request_nonce = bytes(12) - masking_iv = bytes(16) + request_nonce = Nonce(bytes(12)) + masking_iv = Bytes16(bytes(16)) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( remote_node_id=node_b_keys["node_id"], request_nonce=request_nonce, @@ -327,8 +328,8 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): manager_a.start_handshake(node_b_keys["node_id"]) # Step 2: Node B creates WHOAREYOU. - request_nonce = bytes(12) - masking_iv = bytes(16) + request_nonce = Nonce(bytes(12)) + masking_iv = Bytes16(bytes(16)) id_nonce, whoareyou_authdata, _, challenge_data = manager_b.create_whoareyou( remote_node_id=node_a_keys["node_id"], request_nonce=request_nonce, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_keys.py b/tests/lean_spec/subspecs/networking/discovery/test_keys.py index b30bc151..038fd31a 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_keys.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_keys.py @@ -1,7 +1,5 @@ """Tests for Discovery v5 key derivation.""" -import pytest - from lean_spec.subspecs.networking.discovery.crypto import ( generate_secp256k1_keypair, pubkey_to_uncompressed, @@ -81,21 +79,6 @@ def test_order_matters(self): assert keys_ab != keys_ba - def test_invalid_secret_length_raises(self): - """Test that invalid secret length raises ValueError.""" - with pytest.raises(ValueError, match="Secret must be 33 bytes"): - derive_keys(bytes(32), bytes(32), bytes(32), make_challenge_data()) # type: ignore[arg-type] - - def test_invalid_initiator_id_length_raises(self): - """Test that invalid initiator ID length raises ValueError.""" - with pytest.raises(ValueError, match="Initiator ID must be 32 bytes"): - derive_keys(bytes(33), bytes(31), bytes(32), make_challenge_data()) # type: ignore[arg-type] - - def test_invalid_recipient_id_length_raises(self): - """Test that invalid recipient ID length raises ValueError.""" - with pytest.raises(ValueError, match="Recipient ID must be 32 bytes"): - derive_keys(bytes(33), bytes(32), bytes(31), make_challenge_data()) # type: ignore[arg-type] - class TestDeriveKeysFromPubkey: """Tests for key derivation from ECDH.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py index 7ea230dc..30495bae 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_packet.py @@ -4,7 +4,7 @@ from lean_spec.subspecs.networking.discovery.config import MAX_PACKET_SIZE, MIN_PACKET_SIZE from lean_spec.subspecs.networking.discovery.crypto import aes_ctr_encrypt -from lean_spec.subspecs.networking.discovery.messages import PacketFlag +from lean_spec.subspecs.networking.discovery.messages import IdNonce, Nonce, PacketFlag from lean_spec.subspecs.networking.discovery.packet import ( HANDSHAKE_HEADER_SIZE, MESSAGE_AUTHDATA_SIZE, @@ -22,7 +22,7 @@ generate_nonce, ) from lean_spec.subspecs.networking.types import NodeId, SeqNumber -from lean_spec.types import Bytes16 +from lean_spec.types import Bytes16, Bytes33, Bytes64 class TestNonceGeneration: @@ -75,7 +75,7 @@ class TestWhoAreYouAuthdata: def test_encode_whoareyou_authdata(self): """Test WHOAREYOU authdata encoding.""" - id_nonce = bytes(16) + id_nonce = IdNonce(bytes(16)) enr_seq = SeqNumber(42) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -84,24 +84,24 @@ def test_encode_whoareyou_authdata(self): def test_decode_whoareyou_authdata(self): """Test WHOAREYOU authdata decoding.""" - id_nonce = bytes.fromhex("aa" * 16) + id_nonce = IdNonce(bytes.fromhex("aa" * 16)) enr_seq = SeqNumber(12345) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce + assert decoded.id_nonce == id_nonce assert decoded.enr_seq == enr_seq def test_roundtrip(self): """Test encoding then decoding preserves values.""" - id_nonce = bytes.fromhex("01" * 16) + id_nonce = IdNonce(bytes.fromhex("01" * 16)) enr_seq = SeqNumber(2**63 - 1) # Max uint64 authdata = encode_whoareyou_authdata(id_nonce, enr_seq) decoded = decode_whoareyou_authdata(authdata) - assert bytes(decoded.id_nonce) == id_nonce + assert decoded.id_nonce == id_nonce assert decoded.enr_seq == enr_seq def test_invalid_size_raises(self): @@ -116,8 +116,8 @@ class TestHandshakeAuthdata: def test_encode_handshake_authdata(self): """Test HANDSHAKE authdata encoding.""" src_id = NodeId(bytes(32)) - id_signature = bytes(64) - eph_pubkey = bytes([0x02]) + bytes(32) # Compressed pubkey format + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) @@ -128,8 +128,8 @@ def test_encode_handshake_authdata(self): def test_decode_handshake_authdata(self): """Test HANDSHAKE authdata decoding.""" src_id = NodeId(bytes.fromhex("aa" * 32)) - id_signature = bytes.fromhex("bb" * 64) - eph_pubkey = bytes([0x02]) + bytes.fromhex("cc" * 32) + id_signature = Bytes64(bytes.fromhex("bb" * 64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes.fromhex("cc" * 32)) authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey) decoded = decode_handshake_authdata(authdata) @@ -144,8 +144,8 @@ def test_decode_handshake_authdata(self): def test_with_enr_record(self): """Test HANDSHAKE authdata with ENR record.""" src_id = NodeId(bytes(32)) - id_signature = bytes(64) - eph_pubkey = bytes([0x02]) + bytes(32) + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33(bytes([0x02]) + bytes(32)) record = b"enr:-IS4QHCYrY..." # Mock ENR authdata = encode_handshake_authdata(src_id, id_signature, eph_pubkey, record) @@ -153,25 +153,6 @@ def test_with_enr_record(self): assert decoded.record == record - def test_invalid_src_id_length_raises(self): - """Test that invalid src_id length raises ValueError.""" - with pytest.raises(ValueError, match="Source ID must be 32 bytes"): - encode_handshake_authdata( - bytes(31), # type: ignore[arg-type] - bytes(64), - bytes(33), - ) - - def test_invalid_signature_length_raises(self): - """Test that invalid signature length raises ValueError.""" - with pytest.raises(ValueError, match="Signature must be 64 bytes"): - encode_handshake_authdata(NodeId(bytes(32)), bytes(63), bytes(33)) - - def test_invalid_eph_pubkey_length_raises(self): - """Test that invalid ephemeral pubkey length raises ValueError.""" - with pytest.raises(ValueError, match="Ephemeral pubkey must be 33 bytes"): - encode_handshake_authdata(NodeId(bytes(32)), bytes(64), bytes(32)) - class TestPacketEncoding: """Tests for full packet encoding/decoding.""" @@ -180,10 +161,10 @@ def test_encode_message_packet(self): """Test MESSAGE packet encoding.""" dest_node_id = NodeId(bytes(32)) src_node_id = NodeId(bytes(32)) - nonce = bytes(12) + nonce = Nonce(bytes(12)) authdata = encode_message_authdata(src_node_id) message = b"encrypted message" - encryption_key = bytes(16) + encryption_key = Bytes16(bytes(16)) packet = encode_packet( dest_node_id=dest_node_id, @@ -200,8 +181,8 @@ def test_encode_message_packet(self): def test_encode_whoareyou_packet(self): """Test WHOAREYOU packet encoding.""" dest_node_id = NodeId(bytes(32)) - nonce = bytes(12) - id_nonce = bytes(16) + nonce = Nonce(bytes(12)) + id_nonce = IdNonce(bytes(16)) authdata = encode_whoareyou_authdata(id_nonce, SeqNumber(0)) packet = encode_packet( @@ -210,7 +191,7 @@ def test_encode_whoareyou_packet(self): nonce=nonce, authdata=authdata, message=b"", - encryption_key=None, # WHOAREYOU doesn't encrypt + encryption_key=None, ) # WHOAREYOU has no message content @@ -220,8 +201,8 @@ def test_encode_whoareyou_packet(self): def test_decode_packet_header(self): """Test packet header decoding.""" local_node_id = NodeId(bytes(32)) - nonce = bytes(12) - authdata = encode_whoareyou_authdata(bytes(16), SeqNumber(42)) + nonce = Nonce(bytes(12)) + authdata = encode_whoareyou_authdata(IdNonce(bytes(16)), SeqNumber(42)) packet = encode_packet( dest_node_id=local_node_id, @@ -235,34 +216,10 @@ def test_decode_packet_header(self): header, message_bytes, _message_ad = decode_packet_header(local_node_id, packet) assert header.flag == PacketFlag.WHOAREYOU - assert bytes(header.nonce) == nonce + assert header.nonce == nonce assert header.authdata == authdata assert message_bytes == b"" - def test_invalid_dest_node_id_length_raises(self): - """Test that invalid dest_node_id length raises ValueError.""" - with pytest.raises(ValueError, match="Destination node ID must be 32 bytes"): - encode_packet( - dest_node_id=bytes(31), # type: ignore[arg-type] - flag=PacketFlag.MESSAGE, - nonce=bytes(12), - authdata=bytes(32), - message=b"", - encryption_key=bytes(16), - ) - - def test_invalid_nonce_length_raises(self): - """Test that invalid nonce length raises ValueError.""" - with pytest.raises(ValueError, match="Nonce must be 12 bytes"): - encode_packet( - dest_node_id=NodeId(bytes(32)), - flag=PacketFlag.MESSAGE, - nonce=bytes(11), - authdata=bytes(32), - message=b"", - encryption_key=bytes(16), - ) - class TestConstants: """Tests for packet format constants.""" @@ -327,8 +284,8 @@ def test_encode_packet_enforces_max_size(self): """encode_packet raises error if packet exceeds max size.""" src_id = NodeId(bytes(32)) dest_id = NodeId(bytes(32)) - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) # Create authdata. authdata = encode_message_authdata(src_id) @@ -391,7 +348,7 @@ def test_message_flag_without_encryption_key_raises(self): encode_packet( dest_node_id=NodeId(bytes(32)), flag=PacketFlag.MESSAGE, - nonce=bytes(12), + nonce=Nonce(bytes(12)), authdata=encode_message_authdata(NodeId(bytes(32))), message=b"\x01\xc2\x01\x01", encryption_key=None, @@ -401,35 +358,21 @@ def test_handshake_flag_without_encryption_key_raises(self): """HANDSHAKE packets require an encryption key.""" authdata = encode_handshake_authdata( src_id=NodeId(bytes(32)), - id_signature=bytes(64), - eph_pubkey=bytes([0x02]) + bytes(32), + id_signature=Bytes64(bytes(64)), + eph_pubkey=Bytes33(bytes([0x02]) + bytes(32)), ) with pytest.raises(ValueError, match="Encryption key required"): encode_packet( dest_node_id=NodeId(bytes(32)), flag=PacketFlag.HANDSHAKE, - nonce=bytes(12), + nonce=Nonce(bytes(12)), authdata=authdata, message=b"\x01\xc2\x01\x01", encryption_key=None, ) -class TestAuthdataInvalidLengths: - """Edge cases for authdata encoding with invalid input lengths.""" - - def test_encode_whoareyou_authdata_wrong_id_nonce_length(self): - """WHOAREYOU authdata rejects id_nonce that is not 16 bytes.""" - with pytest.raises(ValueError, match="ID nonce must be 16 bytes"): - encode_whoareyou_authdata(bytes(15), SeqNumber(0)) - - def test_encode_message_authdata_wrong_src_id_length(self): - """MESSAGE authdata rejects src_id that is not 32 bytes.""" - with pytest.raises(ValueError, match="Source ID must be 32 bytes"): - encode_message_authdata(bytes(31)) # type: ignore[arg-type] - - class TestPacketProtocolValidation: """Protocol ID and version validation in packet decoding.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_routing.py b/tests/lean_spec/subspecs/networking/discovery/test_routing.py index dff72956..1ea5578b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_routing.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_routing.py @@ -19,9 +19,8 @@ ) from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.enr.eth2 import FAR_FUTURE_EPOCH -from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber from lean_spec.types import Bytes64 -from lean_spec.types.byte_arrays import Bytes4 class TestXorDistance: @@ -447,7 +446,7 @@ def test_fork_filter_rejects_without_enr(self, local_node_id, remote_node_id): """With fork filter, nodes without ENR are rejected.""" table = RoutingTable( local_id=local_node_id, - local_fork_digest=Bytes4(bytes(4)), + local_fork_digest=ForkDigest(bytes(4)), ) entry = NodeEntry(node_id=remote_node_id, enr=None) @@ -457,7 +456,7 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ """Nodes without eth2 data are rejected when filtering.""" table = RoutingTable( local_id=local_node_id, - local_fork_digest=Bytes4(bytes(4)), + local_fork_digest=ForkDigest(bytes(4)), ) enr = ENR( @@ -472,7 +471,7 @@ def test_fork_filter_rejects_without_eth2_data(self, local_node_id, remote_node_ def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id): """Node with different fork_digest is rejected.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Build eth2 bytes with a different fork digest. @@ -491,7 +490,7 @@ def test_fork_filter_rejects_mismatched_fork(self, local_node_id, remote_node_id def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): """Node with matching fork_digest is accepted.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Build eth2 bytes with the same fork digest. @@ -513,7 +512,7 @@ def test_fork_filter_accepts_matching_fork(self, local_node_id, remote_node_id): def test_is_fork_compatible_method(self, local_node_id): """Verify is_fork_compatible for compatible, incompatible, and no-ENR entries.""" - local_fork = Bytes4(bytes.fromhex("12345678")) + local_fork = ForkDigest(bytes.fromhex("12345678")) table = RoutingTable(local_id=local_node_id, local_fork_digest=local_fork) # Compatible entry. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py index bda15a64..9e4fb3fb 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_session.py @@ -2,8 +2,6 @@ import time -import pytest - from lean_spec.subspecs.networking.discovery.messages import Port from lean_spec.subspecs.networking.discovery.session import ( BondCache, @@ -11,6 +9,9 @@ SessionCache, ) from lean_spec.subspecs.networking.types import NodeId +from lean_spec.types import Bytes16 + +ZERO_KEY = Bytes16(bytes(16)) class TestSession: @@ -20,8 +21,8 @@ def test_create_session(self): """Test session creation.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time(), last_seen=time.time(), is_initiator=True, @@ -35,8 +36,8 @@ def test_is_expired_false_for_new_session(self): """Test that new session is not expired.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time(), last_seen=time.time(), is_initiator=True, @@ -48,8 +49,8 @@ def test_is_expired_true_for_old_session(self): """Test that old session is expired.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time() - 7200, # 2 hours ago last_seen=time.time() - 7200, is_initiator=True, @@ -61,8 +62,8 @@ def test_touch_updates_last_seen(self): """Test that touch updates last_seen timestamp.""" session = Session( node_id=NodeId(bytes(32)), - send_key=bytes(16), - recv_key=bytes(16), + send_key=ZERO_KEY, + recv_key=ZERO_KEY, created_at=time.time() - 100, last_seen=time.time() - 100, is_initiator=True, @@ -81,8 +82,8 @@ def test_create_and_get_session(self): """Test creating and retrieving a session.""" cache = SessionCache() node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key = bytes(16) - recv_key = bytes(16) + send_key = ZERO_KEY + recv_key = ZERO_KEY session = cache.create(node_id, send_key, recv_key, is_initiator=True) @@ -101,7 +102,7 @@ def test_get_expired_returns_none(self): cache = SessionCache(timeout_secs=0.001) node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) assert cache.get(node_id) is None @@ -112,7 +113,7 @@ def test_remove_session(self): cache = SessionCache() node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.remove(node_id) assert cache.get(node_id) is None @@ -126,7 +127,7 @@ def test_touch_updates_session(self): cache = SessionCache() node_id = NodeId(bytes(32)) - cache.create(node_id, bytes(16), bytes(16), is_initiator=True) + cache.create(node_id, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.touch(node_id) def test_touch_nonexistent_returns_false(self): @@ -140,18 +141,18 @@ def test_count(self): assert cache.count() == 0 - cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 1 - cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 def test_cleanup_expired(self): """Test expired session cleanup.""" cache = SessionCache(timeout_secs=0.001) - cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) - cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) removed = cache.cleanup_expired() @@ -166,36 +167,20 @@ def test_eviction_when_full(self): node2 = NodeId(bytes.fromhex("02" + "00" * 31)) node3 = NodeId(bytes.fromhex("03" + "00" * 31)) - cache.create(node1, bytes(16), bytes(16), is_initiator=True) + cache.create(node1, ZERO_KEY, ZERO_KEY, is_initiator=True) time.sleep(0.01) # Ensure different timestamps - cache.create(node2, bytes(16), bytes(16), is_initiator=True) + cache.create(node2, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 # Adding third should evict first - cache.create(node3, bytes(16), bytes(16), is_initiator=True) + cache.create(node3, ZERO_KEY, ZERO_KEY, is_initiator=True) assert cache.count() == 2 assert cache.get(node1) is None # Oldest should be evicted assert cache.get(node2) is not None assert cache.get(node3) is not None - def test_invalid_node_id_length_raises(self): - """Test that invalid node ID length raises ValueError.""" - cache = SessionCache() - with pytest.raises(ValueError, match="Node ID must be 32 bytes"): - cache.create(bytes(31), bytes(16), bytes(16), is_initiator=True) # type: ignore[arg-type] - - def test_invalid_key_length_raises(self): - """Test that invalid key lengths raise ValueError.""" - cache = SessionCache() - - with pytest.raises(ValueError, match="Send key must be 16 bytes"): - cache.create(NodeId(bytes(32)), bytes(15), bytes(16), is_initiator=True) - - with pytest.raises(ValueError, match="Recv key must be 16 bytes"): - cache.create(NodeId(bytes(32)), bytes(16), bytes(15), is_initiator=True) - def test_endpoint_keying_separates_sessions(self): """Same node_id at different ip:port has separate sessions. @@ -204,15 +189,15 @@ def test_endpoint_keying_separates_sessions(self): """ cache = SessionCache() node_id = NodeId(bytes.fromhex("aa" * 32)) - send_key_1 = bytes([0x01] * 16) - send_key_2 = bytes([0x02] * 16) + send_key_1 = Bytes16(bytes([0x01] * 16)) + send_key_2 = Bytes16(bytes([0x02] * 16)) # Create sessions for same node at different endpoints. cache.create( - node_id, send_key_1, bytes(16), is_initiator=True, ip="10.0.0.1", port=Port(9000) + node_id, send_key_1, ZERO_KEY, is_initiator=True, ip="10.0.0.1", port=Port(9000) ) cache.create( - node_id, send_key_2, bytes(16), is_initiator=True, ip="10.0.0.2", port=Port(9000) + node_id, send_key_2, ZERO_KEY, is_initiator=True, ip="10.0.0.2", port=Port(9000) ) # Each endpoint retrieves its own session. diff --git a/tests/lean_spec/subspecs/networking/discovery/test_transport.py b/tests/lean_spec/subspecs/networking/discovery/test_transport.py index 95f6a613..223f29f5 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_transport.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_transport.py @@ -279,7 +279,7 @@ async def test_stop_cancels_pending_requests(self, local_node_id, local_private_ request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -318,7 +318,7 @@ def test_create_pending_request(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -410,7 +410,7 @@ def test_pending_multi_request_creation(self, local_node_id, local_private_key, request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, @@ -433,7 +433,7 @@ def test_pending_multi_request_expected_total_tracking(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, @@ -469,7 +469,7 @@ async def test_queue(): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=3, @@ -575,7 +575,7 @@ def test_pending_request_stores_request_id(self): request_id=b"\x01\x02\x03\x04", dest_node_id=NodeId(bytes(32)), sent_at=123.456, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -599,7 +599,7 @@ async def test_future(): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -635,7 +635,7 @@ def test_pending_request_future_cancellation(self): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -666,7 +666,7 @@ def test_request_id_bytes_for_dict_lookup(self): request_id=request_id_1, dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message1, future=future1, ) @@ -675,7 +675,7 @@ def test_request_id_bytes_for_dict_lookup(self): request_id=request_id_2, dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message2, future=future2, ) @@ -699,7 +699,7 @@ def test_pending_request_stores_nonce_for_whoareyou_matching(self): loop = asyncio.new_event_loop() future: asyncio.Future = loop.create_future() - nonce = b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c" + nonce = Nonce(b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c") message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1)) pending = PendingRequest( @@ -727,7 +727,7 @@ def test_pending_request_stores_message_for_retransmission(self): request_id=b"\x01", dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=message, future=future, ) @@ -785,7 +785,7 @@ async def test_pending_requests_cleared_on_stop( request_id=bytes([i]), dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -828,7 +828,7 @@ async def test_pending_request_futures_cancelled_on_stop( request_id=bytes([i]), dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -1092,7 +1092,7 @@ async def test_response_completes_pending_request_future( request_id=request_id, dest_node_id=remote_node_id, sent_at=loop.time(), - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), future=future, ) @@ -1128,7 +1128,7 @@ async def test_response_enqueued_for_multi_request( request_id=request_id, dest_node_id=remote_node_id, sent_at=0.0, - nonce=bytes(12), + nonce=Nonce(bytes(12)), message=MagicMock(), response_queue=queue, expected_total=None, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index 6b9c9a20..51f34e6d 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -29,9 +29,11 @@ from lean_spec.subspecs.networking.discovery.messages import ( Distance, FindNode, + IdNonce, IPv4, MessageType, Nodes, + Nonce, PacketFlag, Ping, Pong, @@ -53,7 +55,7 @@ encode_whoareyou_authdata, ) from lean_spec.subspecs.networking.discovery.routing import log2_distance, xor_distance -from lean_spec.subspecs.networking.types import NodeId, SeqNumber +from lean_spec.subspecs.networking.types import SeqNumber from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes33, Bytes64 from lean_spec.types.uint import Uint8 from tests.lean_spec.helpers import make_challenge_data @@ -67,27 +69,27 @@ ) # Spec test vector values for ECDH and key derivation. -SPEC_NONCE = bytes.fromhex("0102030405060708090a0b0c") +SPEC_NONCE = Nonce(bytes.fromhex("0102030405060708090a0b0c")) SPEC_CHALLENGE_DATA = bytes.fromhex( "000000000000000000000000000000006469736376350001010102030405060708090a0b0c" "00180102030405060708090a0b0c0d0e0f100000000000000000" ) # Spec ephemeral keypair for ECDH / ID nonce signing. -SPEC_EPHEMERAL_KEY = bytes.fromhex( - "fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736" +SPEC_EPHEMERAL_KEY = Bytes32( + bytes.fromhex("fb757dc581730490a1d7a00deea65e9b1936924caaea8f44d476014856b68736") ) -SPEC_EPHEMERAL_PUBKEY = bytes.fromhex( - "039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231" +SPEC_EPHEMERAL_PUBKEY = Bytes33( + bytes.fromhex("039961e4c2356d61bedb83052c115d311acb3a96f5777296dcf297351130266231") ) # Derived session keys from spec HKDF test vector. -SPEC_INITIATOR_KEY = bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571") -SPEC_RECIPIENT_KEY = bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5") +SPEC_INITIATOR_KEY = Bytes16(bytes.fromhex("dccc82d81bd610f4f76d3ebe97a40571")) +SPEC_RECIPIENT_KEY = Bytes16(bytes.fromhex("ac74bb8773749920b0d3a8881c173ec5")) # AES-GCM test vector values. -SPEC_AES_KEY = bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b") -SPEC_AES_NONCE = bytes.fromhex("27b5af763c446acd2749fe8e") +SPEC_AES_KEY = Bytes16(bytes.fromhex("9f2d77db7004bf8a1a85107ac686990b")) +SPEC_AES_NONCE = Bytes12(bytes.fromhex("27b5af763c446acd2749fe8e")) # PING message plaintext (type 0x01, RLP [1]). SPEC_PING_PLAINTEXT = bytes.fromhex("01c20101") @@ -108,9 +110,11 @@ def test_node_b_id_from_privkey(self): int.from_bytes(NODE_B_PRIVKEY, "big"), ec.SECP256K1(), ) - pubkey_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + pubkey_bytes = Bytes33( + private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) computed = compute_node_id(pubkey_bytes) @@ -126,9 +130,11 @@ def test_node_a_id_from_privkey(self): int.from_bytes(NODE_A_PRIVKEY, "big"), ec.SECP256K1(), ) - pubkey_bytes = private_key.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + pubkey_bytes = Bytes33( + private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) computed = compute_node_id(pubkey_bytes) assert bytes(computed) == NODE_A_ID @@ -144,13 +150,15 @@ def test_bidirectional_ecdh(self): int.from_bytes(NODE_A_PRIVKEY, "big"), ec.SECP256K1(), ) - a_pubkey_bytes = a_privkey.public_key().public_bytes( - encoding=serialization.Encoding.X962, - format=serialization.PublicFormat.CompressedPoint, + a_pubkey_bytes = Bytes33( + a_privkey.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) ) - shared_ab = ecdh_agree(Bytes32(NODE_A_PRIVKEY), NODE_B_PUBKEY) - shared_ba = ecdh_agree(Bytes32(NODE_B_PRIVKEY), a_pubkey_bytes) + shared_ab = ecdh_agree(NODE_A_PRIVKEY, NODE_B_PUBKEY) + shared_ba = ecdh_agree(NODE_B_PRIVKEY, a_pubkey_bytes) assert shared_ab == shared_ba @@ -168,7 +176,7 @@ def test_ecdh_shared_secret(self): "033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e" ) - shared = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), SPEC_EPHEMERAL_PUBKEY) + shared = ecdh_agree(SPEC_EPHEMERAL_KEY, SPEC_EPHEMERAL_PUBKEY) assert shared == expected_shared @@ -180,13 +188,13 @@ def test_key_derivation_hkdf(self): Uses exact spec challenge_data (with nonce 0102030405060708090a0b0c). """ # Compute ECDH shared secret. - shared_secret = ecdh_agree(Bytes32(SPEC_EPHEMERAL_KEY), NODE_B_PUBKEY) + shared_secret = ecdh_agree(SPEC_EPHEMERAL_KEY, NODE_B_PUBKEY) # Derive keys using exact spec challenge_data. initiator_key, recipient_key = derive_keys( secret=shared_secret, - initiator_id=Bytes32(NODE_A_ID), - recipient_id=Bytes32(NODE_B_ID), + initiator_id=NODE_A_ID, + recipient_id=NODE_B_ID, challenge_data=SPEC_CHALLENGE_DATA, ) @@ -206,10 +214,10 @@ def test_id_nonce_signature(self): """ # Sign using exact spec challenge_data. signature = sign_id_nonce( - private_key_bytes=Bytes32(SPEC_EPHEMERAL_KEY), + private_key_bytes=SPEC_EPHEMERAL_KEY, challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), - dest_node_id=Bytes32(NODE_B_ID), + ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, + dest_node_id=NODE_B_ID, ) expected_sig = bytes.fromhex( @@ -231,8 +239,8 @@ def test_id_nonce_signature(self): assert verify_id_nonce_signature( signature=Bytes64(signature), challenge_data=SPEC_CHALLENGE_DATA, - ephemeral_pubkey=Bytes33(SPEC_EPHEMERAL_PUBKEY), - dest_node_id=Bytes32(NODE_B_ID), + ephemeral_pubkey=SPEC_EPHEMERAL_PUBKEY, + dest_node_id=NODE_B_ID, public_key_bytes=Bytes33(pubkey_bytes), ) @@ -242,16 +250,16 @@ def test_id_nonce_signature_different_challenge_data(self): challenge_data2 = make_challenge_data(bytes([1]) + bytes(15)) sig1 = sign_id_nonce( - Bytes32(NODE_B_PRIVKEY), + NODE_B_PRIVKEY, challenge_data1, - Bytes33(SPEC_EPHEMERAL_PUBKEY), - Bytes32(NODE_A_ID), + SPEC_EPHEMERAL_PUBKEY, + NODE_A_ID, ) sig2 = sign_id_nonce( - Bytes32(NODE_B_PRIVKEY), + NODE_B_PRIVKEY, challenge_data2, - Bytes33(SPEC_EPHEMERAL_PUBKEY), - Bytes32(NODE_A_ID), + SPEC_EPHEMERAL_PUBKEY, + NODE_A_ID, ) assert sig1 != sig2 @@ -266,14 +274,12 @@ def test_aes_gcm_encryption(self): expected_ciphertext = bytes.fromhex("a5d12a2d94b8ccb3ba55558229867dc13bfa3648") # Encrypt. - ciphertext = aes_gcm_encrypt( - Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), SPEC_PING_PLAINTEXT, aad - ) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, SPEC_PING_PLAINTEXT, aad) assert ciphertext == expected_ciphertext # Verify decryption works. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == SPEC_PING_PLAINTEXT @@ -302,8 +308,8 @@ def test_decode_spec_ping_packet(self): assert decoded_authdata.src_id == NODE_A_ID # Decrypt using the spec's read-key (all zeros for this test vector). - read_key = bytes(16) - plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + read_key = Bytes16(bytes(16)) + plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) # PING with request-id=0x00000001 (4 bytes) and enr-seq=2. decoded = decode_message(plaintext) @@ -386,8 +392,8 @@ def test_decode_spec_handshake_with_enr_packet(self): assert len(decoded_authdata.record) > 0 # Decrypt the message using the spec's read-key. - read_key = bytes.fromhex("53b1c075f41876423154e157470c2f48") - plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + read_key = Bytes16(bytes.fromhex("53b1c075f41876423154e157470c2f48")) + plaintext = decrypt_message(read_key, header.nonce, ciphertext, message_ad) # PING with request-id=0x00000001 and enr-seq=1. decoded = decode_message(plaintext) @@ -400,8 +406,8 @@ class TestPacketEncodingRoundtrip: def test_message_packet_roundtrip(self): """MESSAGE packet encodes and decodes correctly.""" - nonce = bytes(12) # 12-byte nonce. - encryption_key = bytes(16) # 16-byte key. + nonce = Nonce(bytes(12)) # 12-byte nonce. + encryption_key = Bytes16(bytes(16)) # 16-byte key. message = b"\x01\xc2\x01\x01" # PING message. authdata = encode_message_authdata(NODE_A_ID) @@ -426,8 +432,8 @@ def test_message_packet_roundtrip(self): def test_whoareyou_packet_roundtrip(self): """WHOAREYOU packet encodes and decodes correctly.""" - nonce = bytes.fromhex("0102030405060708090a0b0c") - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") + nonce = Nonce(bytes.fromhex("0102030405060708090a0b0c")) + id_nonce = IdNonce(bytes.fromhex("0102030405060708090a0b0c0d0e0f10")) enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -453,12 +459,12 @@ def test_whoareyou_packet_roundtrip(self): def test_handshake_packet_roundtrip(self): """HANDSHAKE packet encodes and decodes correctly.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) message = b"\x01\xc2\x01\x01" # PING message. - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -587,8 +593,8 @@ def test_message_packet_header_structure(self): - nonce: 12 bytes - authdata-size: 2 bytes """ - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) message = b"\x01\xc2\x01\x01" authdata = encode_message_authdata(NODE_A_ID) @@ -613,8 +619,8 @@ def test_whoareyou_packet_header_structure(self): - authdata: id-nonce (16) + enr-seq (8) = 24 bytes - no message payload """ - nonce = bytes(12) - id_nonce = bytes(16) + nonce = Nonce(bytes(12)) + id_nonce = IdNonce(bytes(16)) enr_seq = SeqNumber(0) authdata = encode_whoareyou_authdata(id_nonce, enr_seq) @@ -640,13 +646,13 @@ def test_handshake_packet_header_structure(self): - authdata: src-id (32) + sig-size (1) + eph-key-size (1) + sig + eph-key + [record] - encrypted message """ - nonce = bytes(12) - encryption_key = bytes(16) + nonce = Nonce(bytes(12)) + encryption_key = Bytes16(bytes(16)) message = b"\x01\xc2\x01\x01" - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -723,13 +729,13 @@ def test_aes_gcm_empty_plaintext(self): aad = bytes(32) plaintext = b"" - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Empty plaintext should produce just the 16-byte auth tag. assert len(ciphertext) == 16 # Decryption should recover empty plaintext. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == b"" def test_aes_gcm_large_plaintext(self): @@ -737,13 +743,13 @@ def test_aes_gcm_large_plaintext(self): aad = bytes(32) plaintext = bytes(1024) # 1KB of zeros. - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Ciphertext = plaintext length + 16-byte tag. assert len(ciphertext) == len(plaintext) + 16 # Decryption should recover original plaintext. - decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) + decrypted = aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, ciphertext, aad) assert decrypted == plaintext def test_aes_gcm_tampered_ciphertext_fails(self): @@ -751,7 +757,7 @@ def test_aes_gcm_tampered_ciphertext_fails(self): aad = bytes(32) plaintext = b"secret message" - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) + ciphertext = aes_gcm_encrypt(SPEC_AES_KEY, SPEC_AES_NONCE, plaintext, aad) # Tamper with ciphertext by flipping a bit. tampered = bytearray(ciphertext) @@ -760,7 +766,7 @@ def test_aes_gcm_tampered_ciphertext_fails(self): # Decryption of tampered ciphertext should fail with InvalidTag. with pytest.raises(InvalidTag): - aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), tampered, aad) + aes_gcm_decrypt(SPEC_AES_KEY, SPEC_AES_NONCE, tampered, aad) class TestSpecPacketPayloadDecryption: @@ -768,7 +774,7 @@ class TestSpecPacketPayloadDecryption: def test_message_packet_encrypt_decrypt_roundtrip(self): """Encrypt a message in a packet and decrypt using message_ad from decode.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) authdata = encode_message_authdata(NODE_A_ID) @@ -785,16 +791,16 @@ def test_message_packet_encrypt_decrypt_roundtrip(self): header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) assert decrypted == SPEC_PING_PLAINTEXT def test_handshake_packet_encrypt_decrypt_roundtrip(self): """Handshake packet encrypts and decrypts using correct AAD.""" - nonce = bytes(12) + nonce = Nonce(bytes(12)) - id_signature = bytes(64) - eph_pubkey = bytes.fromhex( - "039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5" + id_signature = Bytes64(bytes(64)) + eph_pubkey = Bytes33( + bytes.fromhex("039a003ba6517b473fa0cd74aefe99dadfdb34627f90fec6362df85803908f53a5") ) authdata = encode_handshake_authdata( @@ -817,7 +823,7 @@ def test_handshake_packet_encrypt_decrypt_roundtrip(self): header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) # Decrypt using message_ad as AAD. - decrypted = decrypt_message(SPEC_INITIATOR_KEY, bytes(header.nonce), ciphertext, message_ad) + decrypted = decrypt_message(SPEC_INITIATOR_KEY, header.nonce, ciphertext, message_ad) assert decrypted == SPEC_PING_PLAINTEXT @@ -826,17 +832,12 @@ class TestRoutingWithTestVectorNodeIds: def test_xor_distance_is_symmetric(self): """XOR distance between test vector nodes is symmetric and non-zero.""" - node_a = NodeId(NODE_A_ID) - node_b = NodeId(NODE_B_ID) - - distance = xor_distance(node_a, node_b) + distance = xor_distance(NODE_A_ID, NODE_B_ID) assert distance > 0 - assert xor_distance(node_a, node_b) == xor_distance(node_b, node_a) + assert xor_distance(NODE_A_ID, NODE_B_ID) == xor_distance(NODE_B_ID, NODE_A_ID) def test_log2_distance_is_high(self): """Log2 distance between test vector nodes is high (differ in high bits).""" - node_a = NodeId(NODE_A_ID) - node_b = NodeId(NODE_B_ID) - log_dist = log2_distance(node_a, node_b) + log_dist = log2_distance(NODE_A_ID, NODE_B_ID) assert log_dist > Distance(200) diff --git a/tests/lean_spec/subspecs/networking/enr/test_eth2.py b/tests/lean_spec/subspecs/networking/enr/test_eth2.py index fb437a79..28d945ed 100644 --- a/tests/lean_spec/subspecs/networking/enr/test_eth2.py +++ b/tests/lean_spec/subspecs/networking/enr/test_eth2.py @@ -10,8 +10,8 @@ AttestationSubnets, SyncCommitteeSubnets, ) +from lean_spec.subspecs.networking.types import ForkDigest, Version from lean_spec.types import Uint64 -from lean_spec.types.byte_arrays import Bytes4 class TestEth2Data: @@ -20,17 +20,17 @@ class TestEth2Data: def test_create_eth2_data(self) -> None: """Eth2Data can be created with valid parameters.""" data = Eth2Data( - fork_digest=Bytes4(b"\x12\x34\x56\x78"), - next_fork_version=Bytes4(b"\x02\x00\x00\x00"), + fork_digest=ForkDigest(b"\x12\x34\x56\x78"), + next_fork_version=Version(b"\x02\x00\x00\x00"), next_fork_epoch=Uint64(194048), ) - assert data.fork_digest == Bytes4(b"\x12\x34\x56\x78") + assert data.fork_digest == ForkDigest(b"\x12\x34\x56\x78") assert data.next_fork_epoch == Uint64(194048) def test_no_scheduled_fork_factory(self) -> None: """no_scheduled_fork factory creates correct data.""" - digest = Bytes4(b"\xab\xcd\xef\x01") - version = Bytes4(b"\x01\x00\x00\x00") + digest = ForkDigest(b"\xab\xcd\xef\x01") + version = Version(b"\x01\x00\x00\x00") data = Eth2Data.no_scheduled_fork(digest, version) assert data.fork_digest == digest @@ -40,12 +40,12 @@ def test_no_scheduled_fork_factory(self) -> None: def test_eth2_data_immutable(self) -> None: """Eth2Data is immutable (frozen).""" data = Eth2Data( - fork_digest=Bytes4(b"\x12\x34\x56\x78"), - next_fork_version=Bytes4(b"\x02\x00\x00\x00"), + fork_digest=ForkDigest(b"\x12\x34\x56\x78"), + next_fork_version=Version(b"\x02\x00\x00\x00"), next_fork_epoch=Uint64(0), ) with pytest.raises(ValidationError): - data.fork_digest = Bytes4(b"\x00\x00\x00\x00") + data.fork_digest = ForkDigest(b"\x00\x00\x00\x00") def test_far_future_epoch_value(self) -> None: """FAR_FUTURE_EPOCH is max uint64.""" @@ -189,7 +189,7 @@ def test_from_subnet_ids_with_duplicates(self) -> None: """from_subnet_ids handles duplicates correctly.""" subnets = SyncCommitteeSubnets.from_subnet_ids([1, 1, 1, 3]) assert subnets.subscription_count() == 2 - assert subnets.subscribed_subnets() == [1, 3] + assert subnets.subscribed_subnets() == [SubnetId(1), SubnetId(3)] def test_from_subnet_ids_invalid(self) -> None: """from_subnet_ids() raises for invalid subnet IDs.""" @@ -202,7 +202,7 @@ def test_from_subnet_ids_invalid(self) -> None: def test_subscribed_subnets(self) -> None: """subscribed_subnets() returns correct list.""" subnets = SyncCommitteeSubnets.from_subnet_ids([1, 3]) - assert subnets.subscribed_subnets() == [1, 3] + assert subnets.subscribed_subnets() == [SubnetId(1), SubnetId(3)] def test_subscription_count(self) -> None: """subscription_count() returns correct count.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py b/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py index b3d3cfa1..0b466c55 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_cache_edge_cases.py @@ -6,7 +6,7 @@ from lean_spec.subspecs.networking.gossipsub.mcache import MessageCache, SeenCache from lean_spec.subspecs.networking.gossipsub.message import GossipsubMessage -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId class TestMessageCacheShift: @@ -158,7 +158,7 @@ def test_get_retrieves_cached_message(self) -> None: def test_get_returns_none_for_unknown(self) -> None: """get() returns None for an unknown message ID.""" cache = MessageCache() - assert cache.get(Bytes20(b"\x00" * 20)) is None + assert cache.get(MessageId(b"\x00" * 20)) is None def test_put_duplicate_returns_false(self) -> None: """Putting the same message twice returns False on second call.""" @@ -176,7 +176,7 @@ def test_has_method(self) -> None: cache.put("t", msg) assert cache.has(msg.id) - assert not cache.has(Bytes20(b"\x00" * 20)) + assert not cache.has(MessageId(b"\x00" * 20)) class TestSeenCache: @@ -185,13 +185,13 @@ class TestSeenCache: def test_add_returns_true_for_new(self) -> None: """add() returns True for a new message ID.""" seen = SeenCache(ttl_seconds=120) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") assert seen.add(msg_id, time.time()) is True def test_add_returns_false_for_duplicate(self) -> None: """add() returns False for an already-seen message ID.""" seen = SeenCache(ttl_seconds=120) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") seen.add(msg_id, time.time()) assert seen.add(msg_id, time.time()) is False @@ -200,8 +200,8 @@ def test_cleanup_removes_expired(self) -> None: seen = SeenCache(ttl_seconds=10) now = time.time() - old_id = Bytes20(b"aaaaaaaaaaaaaaaaaaaa") - fresh_id = Bytes20(b"bbbbbbbbbbbbbbbbbbbb") + old_id = MessageId(b"aaaaaaaaaaaaaaaaaaaa") + fresh_id = MessageId(b"bbbbbbbbbbbbbbbbbbbb") seen.add(old_id, now - 20) seen.add(fresh_id, now) @@ -214,7 +214,7 @@ def test_cleanup_no_expired(self) -> None: """cleanup() with no expired entries removes nothing.""" seen = SeenCache(ttl_seconds=120) now = time.time() - seen.add(Bytes20(b"12345678901234567890"), now) + seen.add(MessageId(b"12345678901234567890"), now) removed = seen.cleanup(now) assert removed == 0 @@ -224,7 +224,7 @@ def test_clear_empties_all(self) -> None: """clear() removes all entries.""" seen = SeenCache() for i in range(5): - seen.add(Bytes20(f"x{i:019d}".encode()), time.time()) + seen.add(MessageId(f"x{i:019d}".encode()), time.time()) assert len(seen) == 5 seen.clear() @@ -233,11 +233,11 @@ def test_clear_empties_all(self) -> None: def test_has_method(self) -> None: """The has() method works for seen message IDs.""" seen = SeenCache() - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") seen.add(msg_id, time.time()) assert seen.has(msg_id) - assert not seen.has(Bytes20(b"\x00" * 20)) + assert not seen.has(MessageId(b"\x00" * 20)) class TestGossipsubMessageId: @@ -262,10 +262,10 @@ def test_id_differs_with_different_topic(self) -> None: assert msg1.id != msg2.id def test_id_is_20_bytes(self) -> None: - """Message ID is exactly 20 bytes (Bytes20).""" + """Message ID is exactly 20 bytes.""" msg = GossipsubMessage(topic=b"t", raw_data=b"d") assert len(msg.id) == 20 - assert isinstance(msg.id, Bytes20) + assert isinstance(msg.id, MessageId) def test_id_is_cached(self) -> None: """The ID is computed once and reused on subsequent accesses.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py b/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py index f3c2d589..209a13cb 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_handlers.py @@ -24,7 +24,7 @@ Message, SubOpts, ) -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId from .conftest import add_peer, make_behavior, make_peer @@ -217,7 +217,7 @@ async def test_ihave_ignores_seen(self) -> None: """IHAVE for already-seen messages does not trigger IWANT.""" behavior, capture = make_behavior() peer_id = add_peer(behavior, "peer1") - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") # Mark as seen behavior.seen_cache.add(msg_id, time.time()) @@ -233,7 +233,7 @@ async def test_ihave_partial_seen(self) -> None: behavior, capture = make_behavior() peer_id = add_peer(behavior, "peer1") - seen_id = Bytes20(b"seen_msg_id_1234seen") + seen_id = MessageId(b"seen_msg_id_1234seen") unseen_id = b"unseen_msg_id_12unse" behavior.seen_cache.add(seen_id, time.time()) @@ -535,7 +535,7 @@ def test_idontwant_populates_peer_set(self) -> None: state = behavior._peers[peer_id] for mid in msg_ids: - assert Bytes20(mid) in state.dont_want_ids + assert MessageId(mid) in state.dont_want_ids def test_idontwant_unknown_peer(self) -> None: """IDONTWANT for unknown peer is silently ignored.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py b/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py index 4743831e..02bb934d 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/test_heartbeat.py @@ -20,7 +20,7 @@ ControlMessage, ControlPrune, ) -from lean_spec.types import Bytes20 +from lean_spec.subspecs.networking.gossipsub.types import MessageId from .conftest import add_peer, make_behavior, make_peer @@ -255,7 +255,7 @@ async def test_cleans_seen_cache(self) -> None: behavior, _ = make_behavior() behavior.seen_cache = SeenCache(ttl_seconds=1) - msg_id = Bytes20(b"12345678901234567890") + msg_id = MessageId(b"12345678901234567890") behavior.seen_cache.add(msg_id, time.time() - 10) # Already expired await behavior._heartbeat() @@ -310,7 +310,7 @@ async def test_clears_idontwant_sets(self) -> None: """Heartbeat clears per-peer IDONTWANT sets.""" behavior, _ = make_behavior() pid = add_peer(behavior, "peer1") - behavior._peers[pid].dont_want_ids.add(Bytes20(b"12345678901234567890")) + behavior._peers[pid].dont_want_ids.add(MessageId(b"12345678901234567890")) assert len(behavior._peers[pid].dont_want_ids) == 1 diff --git a/tests/lean_spec/subspecs/networking/test_network_service.py b/tests/lean_spec/subspecs/networking/test_network_service.py index 9e657f4d..027fb75e 100644 --- a/tests/lean_spec/subspecs/networking/test_network_service.py +++ b/tests/lean_spec/subspecs/networking/test_network_service.py @@ -171,7 +171,7 @@ async def test_attestation_processed_by_store( attestation = SignedAttestation( validator_id=ValidatorIndex(42), - message=AttestationData( + data=AttestationData( slot=Slot(1), head=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), target=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), @@ -218,7 +218,7 @@ async def test_attestation_ignored_in_idle_state( attestation = SignedAttestation( validator_id=ValidatorIndex(99), - message=AttestationData( + data=AttestationData( slot=Slot(1), head=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), target=Checkpoint(root=Bytes32.zero(), slot=Slot(1)), diff --git a/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py b/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py index baa7455f..f2a4a6c6 100644 --- a/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py +++ b/tests/lean_spec/subspecs/networking/transport/identity/test_keypair.py @@ -1,7 +1,5 @@ """Tests for secp256k1 identity keypair.""" -import pytest - from lean_spec.subspecs.networking.transport.identity import ( IdentityKeypair, verify_signature, @@ -34,21 +32,10 @@ def test_generate_unique(self) -> None: def test_from_bytes_roundtrip(self) -> None: """Keypair can be loaded from raw bytes.""" original = IdentityKeypair.generate() - private_bytes = original.private_key_bytes() - - restored = IdentityKeypair.from_bytes(private_bytes) - + restored = IdentityKeypair.from_bytes(original.private_key_bytes()) assert restored.public_key_bytes() == original.public_key_bytes() assert restored.private_key_bytes() == original.private_key_bytes() - def test_from_bytes_invalid_length(self) -> None: - """Loading from invalid bytes raises ValueError.""" - with pytest.raises(ValueError, match="Expected 32 bytes"): - IdentityKeypair.from_bytes(b"\x00" * 16) - - with pytest.raises(ValueError, match="Expected 32 bytes"): - IdentityKeypair.from_bytes(b"\x00" * 64) - def test_sign_and_verify(self) -> None: """Signatures can be verified.""" keypair = IdentityKeypair.generate() diff --git a/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py b/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py index e67aaefa..3c519946 100644 --- a/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py +++ b/tests/lean_spec/subspecs/networking/transport/identity/test_signature.py @@ -8,6 +8,7 @@ create_identity_proof, verify_identity_proof, ) +from lean_spec.types import Bytes32 class TestIdentityProof: @@ -16,25 +17,25 @@ class TestIdentityProof: def test_create_and_verify(self) -> None: """Identity proof can be verified.""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key, noise_public_key) + proof = create_identity_proof(identity_key, public_key) assert verify_identity_proof( identity_key.public_key_bytes(), - noise_public_key, + public_key, proof, ) - def test_verify_wrong_noise_key(self) -> None: - """Verification fails with wrong Noise key.""" + def test_verify_wrong_key(self) -> None: + """Verification fails with wrong public key.""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) - wrong_noise_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) + wrong_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key, noise_public_key) + proof = create_identity_proof(identity_key, public_key) assert not verify_identity_proof( identity_key.public_key_bytes(), - wrong_noise_key, + wrong_key, proof, ) @@ -42,45 +43,45 @@ def test_verify_wrong_identity_key(self) -> None: """Verification fails with wrong identity key.""" identity_key1 = IdentityKeypair.generate() identity_key2 = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key1, noise_public_key) + proof = create_identity_proof(identity_key1, public_key) assert not verify_identity_proof( identity_key2.public_key_bytes(), - noise_public_key, + public_key, proof, ) def test_proof_is_deterministic(self) -> None: """Same inputs produce same proof format (but not same bytes due to ECDSA k).""" identity_key = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof1 = create_identity_proof(identity_key, noise_public_key) - proof2 = create_identity_proof(identity_key, noise_public_key) + proof1 = create_identity_proof(identity_key, public_key) + proof2 = create_identity_proof(identity_key, public_key) - assert verify_identity_proof(identity_key.public_key_bytes(), noise_public_key, proof1) - assert verify_identity_proof(identity_key.public_key_bytes(), noise_public_key, proof2) + assert verify_identity_proof(identity_key.public_key_bytes(), public_key, proof1) + assert verify_identity_proof(identity_key.public_key_bytes(), public_key, proof2) def test_noise_identity_prefix(self) -> None: """NOISE_IDENTITY_PREFIX matches libp2p-noise spec.""" assert NOISE_IDENTITY_PREFIX == b"noise-libp2p-static-key:" - def test_proof_binds_identity_to_noise_key(self) -> None: + def test_proof_binds_identity_to_key(self) -> None: """Proof prevents identity key substitution.""" identity_key_real = IdentityKeypair.generate() identity_key_attacker = IdentityKeypair.generate() - noise_public_key = os.urandom(32) + public_key = Bytes32(os.urandom(32)) - proof = create_identity_proof(identity_key_real, noise_public_key) + proof = create_identity_proof(identity_key_real, public_key) assert verify_identity_proof( identity_key_real.public_key_bytes(), - noise_public_key, + public_key, proof, ) assert not verify_identity_proof( identity_key_attacker.public_key_bytes(), - noise_public_key, + public_key, proof, ) diff --git a/tests/lean_spec/subspecs/networking/transport/test_peer_id.py b/tests/lean_spec/subspecs/networking/transport/test_peer_id.py index a49002d8..67816801 100644 --- a/tests/lean_spec/subspecs/networking/transport/test_peer_id.py +++ b/tests/lean_spec/subspecs/networking/transport/test_peer_id.py @@ -23,6 +23,7 @@ PeerId, PublicKeyProto, ) +from lean_spec.types import Bytes33 # Protobuf tag constants for test assertions _PROTOBUF_TAG_TYPE = 0x08 # (1 << 3) | 0 = field 1, varint @@ -273,14 +274,6 @@ def test_different_keys_different_peerids(self) -> None: assert str(peer_id1) != str(peer_id2) - def test_from_secp256k1_invalid_length(self) -> None: - """from_secp256k1 rejects invalid key lengths.""" - with pytest.raises(ValueError, match="must be 33 bytes"): - PeerId.from_secp256k1(bytes(32)) - - with pytest.raises(ValueError, match="must be 33 bytes"): - PeerId.from_secp256k1(bytes(34)) - class TestPeerIdFormat: """Tests for PeerId format and structure.""" @@ -482,8 +475,8 @@ def test_known_secp256k1_peer_id(self) -> None: This matches the libp2p spec test vector. """ # From spec: 08021221037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99 - key_data = bytes.fromhex( - "037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99" + key_data = Bytes33( + bytes.fromhex("037777e994e452c21604f91de093ce415f5432f701dd8cd1a7a6fea0e630bfca99") ) peer_id = PeerId.from_secp256k1(key_data) diff --git a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py index ab6965d1..17fe747a 100644 --- a/tests/lean_spec/subspecs/ssz/test_signed_attestation.py +++ b/tests/lean_spec/subspecs/ssz/test_signed_attestation.py @@ -17,7 +17,7 @@ def test_encode_decode_signed_attestation_roundtrip() -> None: ) signed_attestation = SignedAttestation( validator_id=ValidatorIndex(0), - message=attestation_data, + data=attestation_data, signature=Signature( path=HashTreeOpening(siblings=HashDigestList(data=[])), rho=Randomness(data=[Fp(0) for _ in range(PROD_CONFIG.RAND_LEN_FE)]), diff --git a/tests/lean_spec/subspecs/sync/test_service.py b/tests/lean_spec/subspecs/sync/test_service.py index 1b5b795c..54d025e2 100644 --- a/tests/lean_spec/subspecs/sync/test_service.py +++ b/tests/lean_spec/subspecs/sync/test_service.py @@ -417,7 +417,7 @@ async def test_attestation_buffered_when_block_unknown( original_fn = sync_service.store.on_gossip_attestation def reject_unknown(signed_attestation, *, is_aggregator=False): - if signed_attestation.message.target.root == unknown_root: + if signed_attestation.data.target.root == unknown_root: raise KeyError("Unknown block") return original_fn(signed_attestation, is_aggregator=is_aggregator) diff --git a/tests/lean_spec/subspecs/validator/test_service.py b/tests/lean_spec/subspecs/validator/test_service.py index 0d074a40..7886f9ab 100644 --- a/tests/lean_spec/subspecs/validator/test_service.py +++ b/tests/lean_spec/subspecs/validator/test_service.py @@ -15,6 +15,7 @@ SignedAttestation, SignedBlockWithAttestation, ValidatorIndex, + ValidatorIndices, ) from lean_spec.subspecs.containers.attestation import AggregationBits from lean_spec.subspecs.containers.slot import Slot @@ -576,11 +577,11 @@ async def capture_attestation(attestation: SignedAttestation) -> None: for signed_att in attestations_produced: validator_id = signed_att.validator_id public_key = key_manager.get_public_key(validator_id) - message_bytes = signed_att.message.data_root_bytes() + message_bytes = signed_att.data.data_root_bytes() is_valid = TARGET_SIGNATURE_SCHEME.verify( pk=public_key, - slot=signed_att.message.slot, + slot=signed_att.data.slot, message=message_bytes, sig=signed_att.signature, ) @@ -619,7 +620,7 @@ async def capture_attestation(attestation: SignedAttestation) -> None: expected_source = store.latest_justified for signed_att in attestations_produced: - data = signed_att.message + data = signed_att.data # Verify head checkpoint references the store's head assert data.head.root == expected_head_root @@ -711,7 +712,9 @@ async def test_block_includes_pending_attestations( attestation_map[vid] = attestation_data proof = AggregatedSignatureProof.aggregate( - participants=AggregationBits.from_validator_indices(participants), + participants=AggregationBits.from_validator_indices( + ValidatorIndices(data=participants) + ), public_keys=public_keys, signatures=signatures, message=data_root, @@ -774,7 +777,7 @@ async def test_multiple_slots_produce_different_attestations( } async def capture_attestation(attestation: SignedAttestation) -> None: - attestations_by_slot[attestation.message.slot].append(attestation) + attestations_by_slot[attestation.data.slot].append(attestation) service = ValidatorService( sync_service=real_sync_service, @@ -792,9 +795,9 @@ async def capture_attestation(attestation: SignedAttestation) -> None: # Attestations at each slot should have the correct slot value for att in attestations_by_slot[Slot(1)]: - assert att.message.slot == Slot(1) + assert att.data.slot == Slot(1) for att in attestations_by_slot[Slot(2)]: - assert att.message.slot == Slot(2) + assert att.data.slot == Slot(2) async def test_proposer_does_not_double_attest( self, @@ -921,7 +924,7 @@ async def capture_attestation(attestation: SignedAttestation) -> None: for signed_att in attestations_produced: validator_id = signed_att.validator_id public_key = key_manager.get_public_key(validator_id) - message_bytes = signed_att.message.data_root_bytes() + message_bytes = signed_att.data.data_root_bytes() # Verification must use the same slot that was used for signing is_valid = TARGET_SIGNATURE_SCHEME.verify( From a7ccb03a530c96e04b8514ee347cd76189d0c99d Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 20 Feb 2026 17:35:35 +0100 Subject: [PATCH 2/2] cleanup --- .claude/ralph-loop.local.md | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 .claude/ralph-loop.local.md diff --git a/.claude/ralph-loop.local.md b/.claude/ralph-loop.local.md deleted file mode 100644 index 305fe8f0..00000000 --- a/.claude/ralph-loop.local.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -active: true -iteration: 137 -max_iterations: 0 -completion_promise: "DONE" -started_at: "2026-02-20T14:14:04Z" ---- - -Can you go with the py architect to sanity check the codebase everywhere to check that the code adheres to all the most modern Python principles. It must be extremely lean, clean, and compact so that it truly serves as a minimal running specification/client that is considered a reference worldwide. Therefore, take inspiration from the best repositories in the world, or even from the Python compiler itself, and create something perfect. Of course, as a specification, it's crucial that things are organized correctly and clearly, so we'll prioritize, for example, storing functions in objects rather than isolated functions, while avoiding excessive abstraction to prevent overly complexifying the codebase.