class HashChain:
    """Builds a linear hash chain over a sequence of events.

    Each event is hashed together with the previous event's hash, forming
    a tamper-evident chain: if any event is modified after the fact, its
    recomputed hash no longer matches and the chain breaks at that point.

    Signing is handled server-side at trace ingestion. The SDK builds
    the hash chain for integrity; the backend signs for authenticity.

    Lifecycle: events can be added until the chain is sealed, either by
    ``finalize()`` (normal completion) or by ``terminate()`` (aborted run).
    A sealed chain rejects further events.
    """

    def __init__(self) -> None:
        self._chain: List[AttestationEnvelope] = []
        self._last_hash: Optional[str] = None
        self._terminated: bool = False
        self._terminate_reason: Optional[str] = None
        # Tracked separately from the reason string so that a caller-supplied
        # reason can never be mistaken for normal finalization.
        self._finalized: bool = False

    @property
    def envelopes(self) -> List[AttestationEnvelope]:
        """Return shallow copies of the event envelopes, in insertion order.

        Copies are returned so callers can mutate them (e.g. attach
        signatures) without altering the chain's internal state.
        """
        return [copy(e) for e in self._chain]

    @property
    def is_terminated(self) -> bool:
        """Whether the chain is sealed (by ``terminate()`` or ``finalize()``)."""
        return self._terminated

    def _check_active(self) -> None:
        """Raise ``RuntimeError`` if the chain no longer accepts events."""
        if self._terminated:
            raise RuntimeError(f"Hash chain terminated: {self._terminate_reason}. No further events can be added.")

    def add_event(self, data: Dict[str, Any]) -> AttestationEnvelope:
        """Hash an event and append it to the chain.

        Args:
            data: The event payload. A ``"_previous_hash"`` key is added to a
                copy of it before hashing; ``data`` itself is not modified.

        Returns:
            The envelope recorded for this event.

        Raises:
            RuntimeError: If the chain has been terminated or finalized.
        """
        self._check_active()
        # Include previous_hash in the hashed payload for chaining
        payload = {**data, "_previous_hash": self._last_hash}
        event_hash = compute_hash(payload)
        envelope = AttestationEnvelope(
            hash=event_hash,
            scope=HashScope.EVENT,
            previous_hash=self._last_hash,
        )
        self._chain.append(envelope)
        self._last_hash = event_hash
        return envelope

    def terminate(self, reason: str) -> None:
        """Permanently stop the chain. No further events or finalization allowed.

        Args:
            reason: Why the chain was stopped; reported by ``to_dict()``.
        """
        self._terminated = True
        self._terminate_reason = reason
        # An explicit termination is always reported, whatever the reason text.
        self._finalized = False

    def finalize(self) -> AttestationEnvelope:
        """Compute a trial-level root hash over all event hashes and seal the chain.

        Returns:
            A TRIAL-scope envelope whose hash covers every event hash and
            whose ``previous_hash`` is the last event's hash.

        Raises:
            RuntimeError: If the chain was already finalized, was terminated,
                or contains no events.
        """
        if self._finalized:
            raise RuntimeError("Cannot finalize hash chain: it has already been finalized.")
        if self._terminated:
            raise RuntimeError(
                f"Cannot finalize terminated hash chain. Trial is non-attestable due to: {self._terminate_reason}"
            )
        if not self._chain:
            raise RuntimeError("Cannot finalize empty hash chain.")
        event_hashes = [e.hash for e in self._chain]
        root_hash = compute_hash({"event_hashes": event_hashes})
        trial_envelope = AttestationEnvelope(
            hash=root_hash,
            scope=HashScope.TRIAL,
            previous_hash=self._last_hash,
        )
        # Seal — no more events after finalization
        self._terminated = True
        self._terminate_reason = "chain finalized"
        self._finalized = True
        return trial_envelope

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the chain for inclusion in trace uploads.

        Returns:
            A dict with an ``"events"`` list of serialized envelopes, plus
            ``"terminated_reason"`` when the chain was stopped via
            ``terminate()``.
        """
        result: Dict[str, Any] = {
            "events": [e.to_dict() for e in self._chain],
        }
        # Only include termination details when the chain was stopped
        # explicitly (not by normal finalization).
        if self._terminated and not self._finalized:
            result["terminated_reason"] = self._terminate_reason
        return result
def _json_default(obj: Any) -> Any:
    """Convert a value the JSON encoder cannot handle natively.

    Tried in order: datetimes become ISO-8601 strings, enum members become
    their value, objects exposing ``to_dict`` are converted through it, and
    dataclasses are expanded with ``asdict``.

    Raises:
        TypeError: If none of the conversions applies.
    """
    if isinstance(obj, datetime):
        converted = obj.isoformat()
    elif isinstance(obj, Enum):
        converted = obj.value
    elif hasattr(obj, "to_dict"):
        converted = obj.to_dict()
    elif hasattr(obj, "__dataclass_fields__"):
        converted = asdict(obj)
    else:
        type_name = type(obj).__name__
        raise TypeError(f"Object of type {type_name} is not JSON serializable")
    return converted


def canonical_json(data: Any) -> str:
    """Serialize data to canonical JSON: sorted keys, compact, deterministic.

    Output is ASCII-only with no whitespace between tokens, so equal inputs
    always yield byte-identical text suitable for hashing.
    """
    encoder_options = {
        "sort_keys": True,
        "separators": (",", ":"),
        "ensure_ascii": True,
        "default": _json_default,
    }
    return json.dumps(data, **encoder_options)
def hmac_sign(secret: bytes, data: bytes) -> str:
    """Sign data with HMAC-SHA256, returning a base64-encoded signature."""
    sig = hmac_mod.new(secret, data, hashlib.sha256).digest()
    return base64.b64encode(sig).decode("ascii")


def hmac_verify(secret: bytes, data: bytes, signature: str) -> bool:
    """Verify a base64-encoded HMAC-SHA256 signature. Timing-safe.

    Args:
        secret: The HMAC key.
        data: The bytes that were signed.
        signature: The candidate signature, as produced by ``hmac_sign``.

    Returns:
        True only if ``signature`` exactly matches the expected signature.
        Malformed input (non-string, non-ASCII text) yields False rather
        than an exception.
    """
    expected = hmac_sign(secret, data)
    if not isinstance(signature, str):
        return False
    # compare_digest raises TypeError for str operands containing non-ASCII
    # characters, so compare bytes instead. "surrogatepass" keeps lone
    # surrogates from raising during the encode.
    candidate = signature.encode("utf-8", "surrogatepass")
    return hmac_mod.compare_digest(candidate, expected.encode("ascii"))
def verify_chain(envelopes: List[AttestationEnvelope]) -> ChainVerification:
    """Check that a sequence of envelopes forms one continuous hash chain.

    The chain is valid when the first envelope has no predecessor
    (``previous_hash`` is None) and every later envelope's ``previous_hash``
    equals the hash of the envelope directly before it. An empty sequence
    is considered valid.

    Returns:
        A ChainVerification; on failure ``break_index`` is the position of
        the first offending envelope and ``error`` describes the mismatch.
    """
    if envelopes and envelopes[0].previous_hash is not None:
        return ChainVerification(
            valid=False,
            break_index=0,
            error="First envelope must have previous_hash=None",
        )

    # Walk adjacent pairs; `position` is the index of the later envelope.
    adjacent = zip(envelopes, envelopes[1:])
    for position, (prior, current) in enumerate(adjacent, start=1):
        if current.previous_hash != prior.hash:
            message = (
                f"Chain broken at index {position}: "
                f"expected previous_hash={prior.hash!r}, "
                f"got {current.previous_hash!r}"
            )
            return ChainVerification(valid=False, break_index=position, error=message)

    return ChainVerification(valid=True)
Trial scope + hash + trial_hash_valid = True + if trial_envelope.scope != HashScope.TRIAL: + trial_hash_valid = False + errors.append(f"Trial envelope has wrong scope: {trial_envelope.scope}") + else: + event_hashes = [e.hash for e in envelopes] + expected_hash = compute_hash({"event_hashes": event_hashes}) + if trial_envelope.hash != expected_hash: + trial_hash_valid = False + errors.append("Trial hash does not match event hashes") + + # 3. Signatures (only if a signing secret is provided) + signatures_valid = True + if signing_secret is not None: + for i, envelope in enumerate(envelopes): + if not envelope.signature: + signatures_valid = False + errors.append(f"Missing signature on event {i}") + else: + if not hmac_verify(signing_secret, envelope.hash.encode("utf-8"), envelope.signature): + signatures_valid = False + errors.append(f"Invalid signature on event {i}") + + if not trial_envelope.signature: + signatures_valid = False + errors.append("Missing signature on trial envelope") + else: + if not hmac_verify(signing_secret, trial_envelope.hash.encode("utf-8"), trial_envelope.signature): + signatures_valid = False + errors.append("Invalid signature on trial envelope") + + valid = chain_valid and trial_hash_valid and signatures_valid + return TrialVerification( + valid=valid, + chain_valid=chain_valid, + trial_hash_valid=trial_hash_valid, + signatures_valid=signatures_valid, + errors=errors, + ) + + +def detect_tampering( + envelopes: List[AttestationEnvelope], + original_data: List[Dict[str, Any]], +) -> TamperingResult: + """Detect which events were modified after being hashed. + + Recomputes the hash for each event (using its stored previous_hash + for chain linkage) and compares against the stored hash. 
def _collect_spans(span: SpanData) -> List[Dict[str, Any]]:
    """Walk the span tree depth-first and return a flat list of span dicts.

    Each span is converted with its ``to_dict()`` method so every field it
    serializes ends up in the hashed payload. The ``children`` entry is
    dropped because the tree is flattened here: a parent's dict comes first,
    followed by each child's subtree in order.

    Args:
        span: Root of the (sub)tree to flatten.

    Returns:
        One dict per span, in pre-order.
    """
    result: List[Dict[str, Any]] = []
    span_dict = span.to_dict()
    # Tolerate dicts without a "children" key; a bare pop() would raise
    # KeyError and abort attestation for the whole trace.
    span_dict.pop("children", None)
    result.append(span_dict)
    for child in span.children:
        result.extend(_collect_spans(child))
    return result
def _write_trace_file(payload: Dict[str, Any]) -> str:
    """Write trace payload to a temp file and return its path.

    The payload is wrapped in a single-element JSON array. Values the JSON
    encoder cannot handle are stringified via ``default=str``.

    Args:
        payload: The trace dict to serialize.

    Returns:
        Path of the newly created temp file. The caller owns the file and
        is responsible for deleting it.

    Raises:
        Whatever serialization or file I/O raises; in that case the temp
        file is removed before the exception propagates.
    """
    fd, path = tempfile.mkstemp(suffix=".json", prefix="layerlens_trace_")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump([payload], f, default=str)
    except BaseException:
        # The caller only learns the path on success, so a failed write has
        # to clean up here or the file is leaked.
        try:
            os.remove(path)
        except OSError:
            log.debug("Failed to remove temp trace file: %s", path)
        raise
    return path
Optional[Dict[str, Any]] = None, +) -> None: + payload = trace_data + if attestation: + payload = {**trace_data, "attestation": attestation} + path = _write_trace_file(payload) try: - with os.fdopen(fd, "w") as f: - json.dump([trace_data], f, default=str) client.traces.upload(path) finally: try: @@ -22,11 +36,16 @@ def upload_trace(client: Any, trace_data: Dict[str, Any]) -> None: log.debug("Failed to remove temp trace file: %s", path) -async def async_upload_trace(client: Any, trace_data: Dict[str, Any]) -> None: - fd, path = tempfile.mkstemp(suffix=".json", prefix="layerlens_trace_") +async def async_upload_trace( + client: Any, + trace_data: Dict[str, Any], + attestation: Optional[Dict[str, Any]] = None, +) -> None: + payload = trace_data + if attestation: + payload = {**trace_data, "attestation": attestation} + path = await asyncio.to_thread(_write_trace_file, payload) try: - with os.fdopen(fd, "w") as f: - json.dump([trace_data], f, default=str) await client.traces.upload(path) finally: try: diff --git a/tests/attestation/__init__.py b/tests/attestation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/attestation/test_chain.py b/tests/attestation/test_chain.py new file mode 100644 index 0000000..b204578 --- /dev/null +++ b/tests/attestation/test_chain.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import pytest + +from layerlens.attestation._chain import HashChain +from layerlens.attestation._envelope import HashScope + + +class TestHashChainBuilding: + def test_single_event(self): + chain = HashChain() + env = chain.add_event({"name": "span-1"}) + assert env.previous_hash is None + assert env.scope == HashScope.EVENT + assert env.hash.startswith("sha256:") + + def test_chain_linking(self): + """Each event links to the previous hash.""" + chain = HashChain() + e1 = chain.add_event({"name": "span-1"}) + e2 = chain.add_event({"name": "span-2"}) + e3 = chain.add_event({"name": "span-3"}) + + assert e1.previous_hash is None + 
assert e2.previous_hash == e1.hash + assert e3.previous_hash == e2.hash + + def test_different_data_different_hashes(self): + chain = HashChain() + e1 = chain.add_event({"name": "a"}) + e2 = chain.add_event({"name": "b"}) + assert e1.hash != e2.hash + + def test_envelopes_property(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.add_event({"name": "span-2"}) + assert len(chain.envelopes) == 2 + + +class TestHashChainFinalization: + def test_finalize_produces_trial_scope(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + trial = chain.finalize() + assert trial.scope == HashScope.TRIAL + + def test_finalize_root_hash_deterministic(self): + """Same events in same order produce the same root hash.""" + + def build(): + c = HashChain() + c.add_event({"name": "a"}) + c.add_event({"name": "b"}) + return c.finalize() + + assert build().hash == build().hash + + def test_finalize_seals_chain(self): + """No events can be added after finalization.""" + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.finalize() + with pytest.raises(RuntimeError, match="terminated"): + chain.add_event({"name": "span-2"}) + + def test_finalize_empty_chain_raises(self): + chain = HashChain() + with pytest.raises(RuntimeError, match="empty"): + chain.finalize() + + def test_finalize_links_to_last_event(self): + chain = HashChain() + chain.add_event({"name": "a"}) + last = chain.add_event({"name": "b"}) + trial = chain.finalize() + assert trial.previous_hash == last.hash + + +class TestHashChainTermination: + def test_terminate_blocks_add(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.terminate("policy_violation") + with pytest.raises(RuntimeError, match="terminated"): + chain.add_event({"name": "span-2"}) + + def test_terminate_blocks_finalize(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.terminate("policy_violation") + with pytest.raises(RuntimeError, match="non-attestable"): + 
chain.finalize() + + def test_is_terminated_flag(self): + chain = HashChain() + assert not chain.is_terminated + chain.terminate("test") + assert chain.is_terminated + + def test_terminate_reason_in_error(self): + chain = HashChain() + chain.terminate("safety_check_failed") + with pytest.raises(RuntimeError, match="safety_check_failed"): + chain.add_event({"name": "span-1"}) + + +class TestHashChainSerialization: + def test_to_dict(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + d = chain.to_dict() + assert "events" in d + assert len(d["events"]) == 1 + assert d["events"][0]["scope"] == "event" + assert d["events"][0]["hash"].startswith("sha256:") + + def test_to_dict_finalized_is_clean(self): + """Normal finalization should not include termination details.""" + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.finalize() + d = chain.to_dict() + assert "terminated_reason" not in d + + def test_to_dict_terminated_includes_reason(self): + """Policy violation termination should include the reason.""" + chain = HashChain() + chain.add_event({"name": "span-1"}) + chain.terminate("policy_violation") + d = chain.to_dict() + assert d["terminated_reason"] == "policy_violation" diff --git a/tests/attestation/test_hash.py b/tests/attestation/test_hash.py new file mode 100644 index 0000000..203c6e6 --- /dev/null +++ b/tests/attestation/test_hash.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import re +from enum import Enum +from datetime import datetime, timezone + +from layerlens.attestation._hash import compute_hash, canonical_json + + +class TestCanonicalJson: + def test_sorted_keys(self): + """Key order must not affect output.""" + a = canonical_json({"b": 2, "a": 1}) + b = canonical_json({"a": 1, "b": 2}) + assert a == b + + def test_compact_format(self): + """No whitespace in output.""" + result = canonical_json({"a": 1, "b": [2, 3]}) + assert " " not in result + assert result == '{"a":1,"b":[2,3]}' + + def 
test_nested_structures(self): + """Nested dicts and lists are handled deterministically.""" + data = {"z": {"y": 1, "x": 2}, "a": [3, 2, 1]} + result = canonical_json(data) + assert result == '{"a":[3,2,1],"z":{"x":2,"y":1}}' + + def test_datetime_serialization(self): + dt = datetime(2026, 3, 23, 12, 0, 0, tzinfo=timezone.utc) + result = canonical_json({"ts": dt}) + assert "2026-03-23" in result + + def test_enum_serialization(self): + class Color(Enum): + RED = "red" + + result = canonical_json({"color": Color.RED}) + assert '"red"' in result + + +class TestComputeHash: + def test_format(self): + """Hash must be 'sha256:' followed by 64 hex chars.""" + h = compute_hash({"test": "data"}) + assert re.match(r"^sha256:[0-9a-f]{64}$", h) + + def test_deterministic(self): + """Same data always produces the same hash.""" + data = {"key": "value", "num": 42} + assert compute_hash(data) == compute_hash(data) + + def test_key_order_irrelevant(self): + """Different key orders produce the same hash.""" + assert compute_hash({"b": 2, "a": 1}) == compute_hash({"a": 1, "b": 2}) + + def test_different_data_different_hash(self): + assert compute_hash({"a": 1}) != compute_hash({"a": 2}) + + def test_empty_dict(self): + h = compute_hash({}) + assert re.match(r"^sha256:[0-9a-f]{64}$", h) + + def test_cross_language_vector(self): + """Pinned vector shared with Go backend (TestComputeCanonicalHash_CrossLanguageVector). + + If this test fails, Python and Go will produce different root hashes + for the same trace, breaking attestation verification. 
+ """ + h = compute_hash({"event_hashes": ["sha256:aaa", "sha256:bbb"]}) + assert h == "sha256:b930d0a2cbda5171b8a12d17445c38b8c0842344f2d691a00d24b3359a854db5" diff --git a/tests/attestation/test_integration.py b/tests/attestation/test_integration.py new file mode 100644 index 0000000..5f8c436 --- /dev/null +++ b/tests/attestation/test_integration.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock + +from layerlens.instrument import span, trace +from layerlens.attestation import verify_chain, detect_tampering +from layerlens.attestation._envelope import HashScope, AttestationEnvelope + + +def _make_client(): + """Create a mock client that captures the uploaded trace JSON.""" + client = Mock() + client.traces = Mock() + uploaded = {} + + def capture(path): + with open(path) as f: + uploaded["data"] = json.load(f) + + client.traces.upload = Mock(side_effect=capture) + return client, uploaded + + +class TestTraceAttestation: + def test_trace_includes_attestation(self): + """@trace should include attestation data in the upload.""" + client, uploaded = _make_client() + + @trace(client) + def my_agent(query: str): + return f"answer to {query}" + + my_agent("hello") + + payload = uploaded["data"][0] + assert "attestation" in payload + att = payload["attestation"] + assert "chain" in att + assert "root_hash" in att + assert att["root_hash"].startswith("sha256:") + assert att["schema_version"] == "1.0" + + def test_trace_with_child_spans(self): + """Attestation chain should include all spans in the tree.""" + client, uploaded = _make_client() + + @trace(client) + def my_agent(query: str): + with span("step-1", kind="tool") as s: + s.output = "result-1" + with span("step-2", kind="llm") as s: + s.output = "result-2" + return "done" + + my_agent("test") + + att = uploaded["data"][0]["attestation"] + chain_events = att["chain"]["events"] + # Root span + 2 child spans = 3 events in the chain + assert len(chain_events) == 3 + 
+ def test_chain_events_are_linked(self): + """Verify the chain in the uploaded payload is valid.""" + client, uploaded = _make_client() + + @trace(client) + def my_agent(query: str): + with span("step-1") as s: + s.output = "r1" + with span("step-2") as s: + s.output = "r2" + return "done" + + my_agent("test") + + chain_events = uploaded["data"][0]["attestation"]["chain"]["events"] + # Reconstruct envelopes and verify chain integrity + envelopes = [ + AttestationEnvelope( + hash=e["hash"], + scope=HashScope(e["scope"]), + previous_hash=e["previous_hash"], + ) + for e in chain_events + ] + result = verify_chain(envelopes) + assert result.valid + + def test_trace_error_still_has_attestation(self): + """Even when the traced function raises, attestation should be present.""" + client, uploaded = _make_client() + + @trace(client) + def failing_agent(): + with span("step-1") as s: + s.output = "ok" + raise ValueError("boom") + + try: + failing_agent() + except ValueError: + pass + + payload = uploaded["data"][0] + assert "attestation" in payload + assert payload["attestation"]["root_hash"].startswith("sha256:") + + def test_modifying_output_breaks_chain(self): + """Changing what the agent said must invalidate the attestation.""" + client, uploaded = _make_client() + + @trace(client) + def my_agent(query: str): + with span("llm-call", kind="llm") as s: + s.output = "the real answer" + return "done" + + my_agent("test") + + att = uploaded["data"][0]["attestation"] + envelopes = [ + AttestationEnvelope( + hash=e["hash"], + scope=HashScope(e["scope"]), + previous_hash=e["previous_hash"], + ) + for e in att["chain"]["events"] + ] + + # Build the original span dicts that were hashed (root + child) + payload = uploaded["data"][0] + original_spans = [] + for s in [payload] + payload.get("children", []): + d = {k: v for k, v in s.items() if k not in ("children", "attestation")} + original_spans.append(d) + + # Verify clean data passes + clean = detect_tampering(envelopes, 
class TestHMACSigning:
    """Unit tests for the HMAC-SHA256 sign/verify helpers."""

    def test_sign_produces_base64(self):
        """The signature is strict base64 wrapping a 32-byte SHA-256 MAC."""
        import base64  # local import: only this test needs it

        sig = hmac_sign(b"test-key", b"sha256:" + b"a" * 64)
        assert isinstance(sig, str)
        # validate=True rejects any character outside the base64 alphabet.
        raw = base64.b64decode(sig, validate=True)
        assert len(raw) == 32

    def test_sign_deterministic(self):
        """Same key and data always produce the same signature."""
        data = b"sha256:" + b"a" * 64
        assert hmac_sign(b"test-key", data) == hmac_sign(b"test-key", data)

    def test_different_data_different_signatures(self):
        """Changing the signed data changes the signature."""
        s1 = hmac_sign(b"test-key", b"sha256:" + b"a" * 64)
        s2 = hmac_sign(b"test-key", b"sha256:" + b"b" * 64)
        assert s1 != s2

    def test_different_keys_different_signatures(self):
        """Changing the key changes the signature."""
        data = b"sha256:" + b"a" * 64
        assert hmac_sign(b"key-1", data) != hmac_sign(b"key-2", data)

    def test_verify_valid(self):
        """A freshly produced signature verifies under the same key."""
        data = b"sha256:" + b"a" * 64
        sig = hmac_sign(b"test-key", data)
        assert hmac_verify(b"test-key", data, sig)

    def test_verify_invalid(self):
        """An arbitrary string is rejected."""
        assert not hmac_verify(b"test-key", b"sha256:" + b"a" * 64, "bogus")

    def test_verify_wrong_data(self):
        """A signature does not verify against different data."""
        sig = hmac_sign(b"test-key", b"sha256:" + b"a" * 64)
        assert not hmac_verify(b"test-key", b"sha256:" + b"b" * 64, sig)

    def test_verify_wrong_key(self):
        """A signature does not verify under a different key."""
        sig = hmac_sign(b"key-1", b"data")
        assert not hmac_verify(b"key-2", b"data", sig)
chain.add_event({"name": "span-1"}) + trial = chain.finalize() + + assert e1.signature is None + assert e1.signing_key_id is None + assert trial.signature is None + + def test_to_dict_omits_signature_when_unsigned(self): + chain = HashChain() + chain.add_event({"name": "span-1"}) + d = chain.to_dict() + + event = d["events"][0] + assert "signature" not in event + assert "signing_key_id" not in event + + +class TestVerifyTrialWithSigning: + """Verify that verify_trial still works with externally-signed envelopes. + + In the server-side signing model, the backend signs the chain after + ingestion. These tests simulate that by manually signing envelopes + and verifying them with verify_trial(). + """ + + def _build_and_sign(self, secret: bytes, key_id: str = "org-123"): + """Build an unsigned chain, then manually sign each envelope.""" + chain = HashChain() + chain.add_event({"name": "a"}) + chain.add_event({"name": "b"}) + envelopes = chain.envelopes + trial = chain.finalize() + + # Simulate server-side signing + for env in envelopes: + env.signature = hmac_sign(secret, env.hash.encode("utf-8")) + env.signing_key_id = key_id + trial.signature = hmac_sign(secret, trial.hash.encode("utf-8")) + trial.signing_key_id = key_id + + return envelopes, trial + + def test_valid_signed_trial(self): + secret = b"test-key" + envelopes, trial = self._build_and_sign(secret) + + result = verify_trial(envelopes, trial, signing_secret=secret) + assert result.valid + assert result.chain_valid + assert result.trial_hash_valid + assert result.signatures_valid + assert result.errors == [] + + def test_tampered_signature_detected(self): + secret = b"test-key" + envelopes, trial = self._build_and_sign(secret) + + # Tamper with the event signature + envelopes[0].signature = "dGFtcGVyZWQ=" # base64("tampered") + + result = verify_trial(envelopes, trial, signing_secret=secret) + assert not result.valid + assert not result.signatures_valid + assert result.chain_valid + assert 
result.trial_hash_valid + + def test_wrong_key_rejects(self): + envelopes, trial = self._build_and_sign(b"key-1") + + result = verify_trial(envelopes, trial, signing_secret=b"key-2") + assert not result.valid + assert not result.signatures_valid + + def test_unsigned_chain_passes_without_secret(self): + """verify_trial without signing_secret ignores missing signatures.""" + chain = HashChain() + chain.add_event({"name": "a"}) + envelopes = chain.envelopes + trial = chain.finalize() + + result = verify_trial(envelopes, trial) + assert result.valid + assert result.signatures_valid # vacuously true + + def test_stripped_signatures_detected(self): + """When signing_secret is provided, missing signatures should fail.""" + secret = b"test-key" + envelopes, trial = self._build_and_sign(secret) + + # Strip signatures + envelopes[0].signature = None + trial.signature = None + + result = verify_trial(envelopes, trial, signing_secret=secret) + assert not result.valid + assert not result.signatures_valid + assert any("Missing signature" in e for e in result.errors) + + def test_single_event_signed_chain(self): + """Signed chain with exactly one event works correctly.""" + secret = b"test-key" + chain = HashChain() + chain.add_event({"name": "only"}) + envelopes = chain.envelopes + trial = chain.finalize() + + # Manually sign + for env in envelopes: + env.signature = hmac_sign(secret, env.hash.encode("utf-8")) + env.signing_key_id = "org-1" + trial.signature = hmac_sign(secret, trial.hash.encode("utf-8")) + trial.signing_key_id = "org-1" + + assert len(envelopes) == 1 + assert envelopes[0].signature is not None + + result = verify_trial(envelopes, trial, signing_secret=secret) + assert result.valid + assert result.signatures_valid diff --git a/tests/attestation/test_verify.py b/tests/attestation/test_verify.py new file mode 100644 index 0000000..b2f34c8 --- /dev/null +++ b/tests/attestation/test_verify.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from 
layerlens.attestation import ( + HashChain, + HashScope, + verify_chain, + verify_trial, + detect_tampering, +) + + +class TestVerifyChain: + def test_valid_chain(self): + chain = HashChain() + chain.add_event({"name": "a"}) + chain.add_event({"name": "b"}) + chain.add_event({"name": "c"}) + result = verify_chain(chain.envelopes) + assert result.valid + assert result.break_index is None + + def test_empty_chain_valid(self): + result = verify_chain([]) + assert result.valid + + def test_single_event_valid(self): + chain = HashChain() + chain.add_event({"name": "a"}) + result = verify_chain(chain.envelopes) + assert result.valid + + def test_broken_first_link(self): + """First envelope must have previous_hash=None.""" + chain = HashChain() + chain.add_event({"name": "a"}) + envelopes = chain.envelopes + # Tamper: set previous_hash on first event + envelopes[0].previous_hash = "sha256:fake" + result = verify_chain(envelopes) + assert not result.valid + assert result.break_index == 0 + + def test_broken_middle_link(self): + chain = HashChain() + chain.add_event({"name": "a"}) + chain.add_event({"name": "b"}) + chain.add_event({"name": "c"}) + envelopes = chain.envelopes + # Tamper: break the link between event 1 and 2 + envelopes[2].previous_hash = "sha256:fake" + result = verify_chain(envelopes) + assert not result.valid + assert result.break_index == 2 + + +class TestVerifyTrial: + def test_valid_trial(self): + chain = HashChain() + chain.add_event({"name": "a"}) + chain.add_event({"name": "b"}) + envelopes = chain.envelopes + trial = chain.finalize() + result = verify_trial(envelopes, trial) + assert result.valid + + def test_wrong_scope_rejected(self): + chain = HashChain() + chain.add_event({"name": "a"}) + envelopes = chain.envelopes + trial = chain.finalize() + trial.scope = HashScope.EVENT # Wrong scope + result = verify_trial(envelopes, trial) + assert not result.valid + assert not result.trial_hash_valid + assert any("scope" in e for e in result.errors) + + 
def test_tampered_trial_hash(self): + chain = HashChain() + chain.add_event({"name": "a"}) + envelopes = chain.envelopes + trial = chain.finalize() + trial.hash = "sha256:" + "0" * 64 # Wrong hash + result = verify_trial(envelopes, trial) + assert not result.valid + assert not result.trial_hash_valid + assert any("does not match" in e for e in result.errors) + + +class TestDetectTampering: + def test_no_tampering(self): + data = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + chain = HashChain() + for d in data: + chain.add_event(d) + result = detect_tampering(chain.envelopes, data) + assert not result.tampered + assert result.modified_indices == [] + assert not result.chain_broken + + def test_detect_modified_event(self): + data = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + chain = HashChain() + for d in data: + chain.add_event(d) + # Tamper with the second event's data + tampered_data = [{"name": "a"}, {"name": "CHANGED"}, {"name": "c"}] + result = detect_tampering(chain.envelopes, tampered_data) + assert result.tampered + assert 1 in result.modified_indices + + def test_detect_multiple_modifications(self): + data = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + chain = HashChain() + for d in data: + chain.add_event(d) + tampered = [{"name": "X"}, {"name": "b"}, {"name": "Z"}] + result = detect_tampering(chain.envelopes, tampered) + assert result.tampered + assert 0 in result.modified_indices + assert 2 in result.modified_indices + + def test_detect_count_mismatch(self): + data = [{"name": "a"}, {"name": "b"}] + chain = HashChain() + for d in data: + chain.add_event(d) + result = detect_tampering(chain.envelopes, [{"name": "a"}]) + assert result.tampered + assert result.chain_broken + + def test_detect_tampering_with_multi_event_chain(self): + """detect_tampering works correctly on multi-event chains.""" + data = [{"name": "a"}, {"name": "b"}, {"name": "c"}] + chain = HashChain() + for d in data: + chain.add_event(d) + + # No tampering — should pass + result 
= detect_tampering(chain.envelopes, data) + assert not result.tampered + assert result.modified_indices == [] + + # Tamper with one event + tampered = [{"name": "a"}, {"name": "CHANGED"}, {"name": "c"}] + result = detect_tampering(chain.envelopes, tampered) + assert result.tampered + assert 1 in result.modified_indices