Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ from cryptography.hazmat.primitives.asymmetric import x25519

This keeps code readable and avoids mental overhead of tracking renamed imports.

### Type Annotations

**Never quote type annotations when `from __future__ import annotations` is present.** With future annotations, all annotations are already lazy strings. Adding quotes is redundant and noisy.

Bad:
```python
from __future__ import annotations

def create(cls) -> "Store":
...
```

Good:
```python
from __future__ import annotations

def create(cls) -> Store:
...
```

The only valid use of quoted annotations is in files that do NOT have `from __future__ import annotations` and need a forward reference. Prefer adding the future import instead.

**Prefer narrow domain types over raw builtins.** Use `Bytes32`, `Bytes33`, `Bytes52` etc. instead of `bytes` in signatures. Spec code should never accept or return `bytes` when a more specific type exists.

### Module-Level Constants

Use docstrings (not comments) to document module-level constants. Place the docstring immediately after the assignment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def set_max_slot_default(self) -> Self:
if isinstance(step, BlockStep):
max_slot_value = max(max_slot_value, step.block.slot)
elif isinstance(step, AttestationStep):
max_slot_value = max(max_slot_value, step.attestation.message.slot)
max_slot_value = max(max_slot_value, step.attestation.data.slot)

self.max_slot = max_slot_value

Expand Down Expand Up @@ -219,9 +219,8 @@ def make_fixture(self) -> Self:
#
# The Store is the node's local view of the chain.
# It starts from a trusted anchor (usually genesis).
store = Store.get_forkchoice_store(
anchor_state=self.anchor_state,
anchor_block=self.anchor_block,
store = self.anchor_state.to_forkchoice_store(
self.anchor_block,
validator_id=DEFAULT_VALIDATOR_ID,
)

Expand Down Expand Up @@ -400,7 +399,7 @@ def _build_block_from_spec(
working_store = working_store.on_gossip_attestation(
SignedAttestation(
validator_id=attestation.validator_id,
message=attestation.data,
data=attestation.data,
signature=signature,
),
scheme=LEAN_ENV_TO_SCHEMES[self.lean_env],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from lean_spec.subspecs.containers.checkpoint import Checkpoint
from lean_spec.subspecs.containers.state.state import State
from lean_spec.subspecs.containers.validator import ValidatorIndex
from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices
from lean_spec.subspecs.ssz import hash_tree_root
from lean_spec.subspecs.xmss.aggregation import AggregatedSignatureProof
from lean_spec.types import Bytes32
Expand All @@ -42,7 +42,7 @@ def _create_dummy_aggregated_proof(validator_ids: list[ValidatorIndex]) -> Aggre
so it will fail verification.
"""
return AggregatedSignatureProof(
participants=AggregationBits.from_validator_indices(validator_ids),
participants=AggregationBits.from_validator_indices(ValidatorIndices(data=validator_ids)),
proof_data=ByteListMiB(data=b"\x00" * 32), # Invalid proof bytes
)

Expand Down Expand Up @@ -216,7 +216,9 @@ def _build_block_from_spec(
data_root = attestation_data.data_root_bytes()

# Create aggregated attestation claiming validator_ids as participants
aggregation_bits = AggregationBits.from_validator_indices(invalid_spec.validator_ids)
aggregation_bits = AggregationBits.from_validator_indices(
ValidatorIndices(data=invalid_spec.validator_ids)
)
invalid_aggregated = AggregatedAttestation(
aggregation_bits=aggregation_bits,
data=attestation_data,
Expand All @@ -238,7 +240,9 @@ def _build_block_from_spec(
]
# Create valid aggregated proof from actual signers
valid_proof = AggregatedSignatureProof.aggregate(
participants=AggregationBits.from_validator_indices(invalid_spec.signer_ids),
participants=AggregationBits.from_validator_indices(
ValidatorIndices(data=invalid_spec.signer_ids)
),
public_keys=signer_public_keys,
signatures=signer_signatures,
message=data_root,
Expand Down
3 changes: 1 addition & 2 deletions src/lean_spec/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from lean_spec.subspecs.containers.block.types import AggregatedAttestations
from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.validator import SubnetId
from lean_spec.subspecs.forkchoice import Store
from lean_spec.subspecs.genesis import GenesisConfig
from lean_spec.subspecs.networking.client import LiveNetworkEventSource
from lean_spec.subspecs.networking.enr import ENR
Expand Down Expand Up @@ -271,7 +270,7 @@ async def _init_from_checkpoint(
# The store treats this as the new "genesis" for fork choice purposes.
# All blocks before the checkpoint are effectively pruned.
validator_id = validator_registry.primary_index() if validator_registry else None
store = Store.get_forkchoice_store(state, anchor_block, validator_id)
store = state.to_forkchoice_store(anchor_block, validator_id)
logger.info(
"Initialized from checkpoint at slot %d (finalized=%s)",
state.slot,
Expand Down
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/chain/clock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
"""

import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from time import time as wall_time
from typing import Callable

from lean_spec.subspecs.containers.slot import Slot
from lean_spec.types import Uint64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ class AggregationBits(BaseBitlist):
LIMIT = int(VALIDATOR_REGISTRY_LIMIT)

@classmethod
def from_validator_indices(
cls, indices: "ValidatorIndices | list[ValidatorIndex]"
) -> AggregationBits:
def from_validator_indices(cls, indices: ValidatorIndices) -> AggregationBits:
"""
Construct aggregation bits from a set of validator indices.

Args:
indices: Validator indices to set in the bitlist. Accepts either
a ValidatorIndices collection or a plain list of ValidatorIndex.
indices: Validator indices to set in the bitlist.

Returns:
AggregationBits with the corresponding indices set to True.
Expand All @@ -36,8 +33,7 @@ def from_validator_indices(
AssertionError: If no indices are provided.
AssertionError: If any index is outside the supported LIMIT.
"""
# Extract list from ValidatorIndices if needed
index_list = indices.data if isinstance(indices, ValidatorIndices) else indices
index_list = indices.data

# Require at least one validator for a valid aggregation.
if not index_list:
Expand All @@ -57,7 +53,7 @@ def from_validator_indices(
# - False elsewhere.
return cls(data=[Boolean(i in ids) for i in range(max_id + 1)])

def to_validator_indices(self) -> "ValidatorIndices":
def to_validator_indices(self) -> ValidatorIndices:
"""
Extract all validator indices encoded in these aggregation bits.

Expand Down
14 changes: 5 additions & 9 deletions src/lean_spec/subspecs/containers/attestation/attestation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import defaultdict

from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.validator import ValidatorIndex
from lean_spec.subspecs.containers.validator import ValidatorIndex, ValidatorIndices
from lean_spec.subspecs.ssz import hash_tree_root
from lean_spec.types import Bytes32, Container

Expand Down Expand Up @@ -56,15 +56,9 @@ class Attestation(Container):
"""The attestation data produced by the validator."""


class SignedAttestation(Container):
class SignedAttestation(Attestation):
"""Validator attestation bundled with its signature."""

validator_id: ValidatorIndex
"""The index of the validator making the attestation."""

message: AttestationData
"""The attestation message signed by the validator."""

signature: Signature
"""Signature aggregation produced by the leanVM (SNARKs in the future)."""

Expand Down Expand Up @@ -103,7 +97,9 @@ def aggregate_by_data(

return [
cls(
aggregation_bits=AggregationBits.from_validator_indices(validator_ids),
aggregation_bits=AggregationBits.from_validator_indices(
ValidatorIndices(data=validator_ids)
),
data=data,
)
for data, validator_ids in data_to_validator_ids.items()
Expand Down
92 changes: 82 additions & 10 deletions src/lean_spec/subspecs/containers/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

from typing import AbstractSet, Collection, Iterable
from typing import TYPE_CHECKING, AbstractSet, Collection, Iterable

from lean_spec.subspecs.chain.clock import Interval
from lean_spec.subspecs.chain.config import INTERVALS_PER_SLOT
from lean_spec.subspecs.ssz.hash import hash_tree_root
from lean_spec.subspecs.xmss.aggregation import (
AggregatedSignatureProof,
Expand All @@ -24,7 +26,7 @@
from ..checkpoint import Checkpoint
from ..config import Config
from ..slot import Slot
from ..validator import ValidatorIndex
from ..validator import ValidatorIndex, ValidatorIndices
from .types import (
HistoricalBlockHashes,
JustificationRoots,
Expand All @@ -33,6 +35,9 @@
Validators,
)

if TYPE_CHECKING:
from lean_spec.subspecs.forkchoice import Store


class State(Container):
"""The main consensus state object."""
Expand Down Expand Up @@ -73,7 +78,7 @@ class State(Container):
"""A bitlist of validators who participated in justifications."""

@classmethod
def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "State":
def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> State:
"""
Generate a genesis state with empty history and proper initial values.

Expand Down Expand Up @@ -117,7 +122,72 @@ def generate_genesis(cls, genesis_time: Uint64, validators: Validators) -> "Stat
justifications_validators=JustificationValidators(data=[]),
)

def process_slots(self, target_slot: Slot) -> "State":
def to_forkchoice_store(
self,
anchor_block: Block,
validator_id: ValidatorIndex | None,
) -> Store:
"""
Initialize a forkchoice store from this state and an anchor block.

The anchor block and this state form the starting point for fork choice.
Both are treated as justified and finalized.

Args:
anchor_block: A trusted block (e.g. genesis or checkpoint).
validator_id: Index of the validator running this store.

Returns:
A new Store instance, ready to accept blocks and attestations.

Raises:
AssertionError:
If the anchor block's state root does not match the hash
of this state.
"""
from lean_spec.subspecs.forkchoice import Store

# Compute the SSZ root of this state.
#
# This is the canonical hash that should appear in the block's state root.
computed_state_root = hash_tree_root(self)

# Check that the block actually points to this state.
#
# If this fails, the caller has supplied inconsistent inputs.
assert anchor_block.state_root == computed_state_root, (
"Anchor block state root must match anchor state hash"
)

# Compute the SSZ root of the anchor block itself.
#
# This root will be used as:
# - the key in the blocks/states maps,
# - the initial head,
# - the root of the initial checkpoints.
anchor_root = hash_tree_root(anchor_block)

# Read the slot at which the anchor block was proposed.
anchor_slot = anchor_block.slot

# Initialize checkpoints from this state.
#
# We explicitly set the root to the anchor block root.
# The state internally might have zero-hash checkpoints (if genesis),
# but the Store must treat the anchor block as the justified/finalized point.
return Store(
time=Interval(anchor_slot * INTERVALS_PER_SLOT),
config=self.config,
head=anchor_root,
safe_target=anchor_root,
latest_justified=self.latest_justified.model_copy(update={"root": anchor_root}),
latest_finalized=self.latest_finalized.model_copy(update={"root": anchor_root}),
blocks={anchor_root: anchor_block},
states={anchor_root: self},
validator_id=validator_id,
)

def process_slots(self, target_slot: Slot) -> State:
"""
Advance the state through empty slots up to, but not including, target_slot.

Expand Down Expand Up @@ -192,7 +262,7 @@ def process_slots(self, target_slot: Slot) -> "State":
# Reached the target slot. Return the advanced state.
return state

def process_block_header(self, block: Block) -> "State":
def process_block_header(self, block: Block) -> State:
"""
Validate the block header and update header-linked state.

Expand Down Expand Up @@ -341,7 +411,7 @@ def process_block_header(self, block: Block) -> "State":
}
)

def process_block(self, block: Block) -> "State":
def process_block(self, block: Block) -> State:
"""
Apply full block processing including header and body.

Expand All @@ -368,7 +438,7 @@ def process_block(self, block: Block) -> "State":
def process_attestations(
self,
attestations: Iterable[AggregatedAttestation],
) -> "State":
) -> State:
"""
Apply attestations and update justification/finalization
according to the Lean Consensus 3SF-mini rules.
Expand Down Expand Up @@ -617,7 +687,7 @@ def process_attestations(
}
)

def state_transition(self, block: Block, valid_signatures: bool = True) -> "State":
def state_transition(self, block: Block, valid_signatures: bool = True) -> State:
"""
Apply the complete state transition function for a block.

Expand Down Expand Up @@ -670,7 +740,7 @@ def build_block(
available_attestations: Iterable[Attestation] | None = None,
known_block_roots: AbstractSet[Bytes32] | None = None,
aggregated_payloads: dict[SignatureKey, list[AggregatedSignatureProof]] | None = None,
) -> tuple[Block, "State", list[AggregatedAttestation], list[AggregatedSignatureProof]]:
) -> tuple[Block, State, list[AggregatedAttestation], list[AggregatedSignatureProof]]:
"""
Build a valid block on top of this state.

Expand Down Expand Up @@ -858,7 +928,9 @@ def aggregate_gossip_signatures(
# The aggregation combines multiple XMSS signatures into a single
# compact proof that can verify all participants signed the message.
if gossip_ids:
participants = AggregationBits.from_validator_indices(gossip_ids)
participants = AggregationBits.from_validator_indices(
ValidatorIndices(data=gossip_ids)
)
proof = AggregatedSignatureProof.aggregate(
participants=participants,
public_keys=gossip_keys,
Expand Down
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/containers/state/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def with_justified(
finalized_slot: Slot,
target_slot: Slot,
value: Boolean,
) -> "JustifiedSlots":
) -> JustifiedSlots:
"""
Return a new bitfield with the justification status updated.

Expand Down
Loading
Loading