diff --git a/docs/adapters/memory-contract.md b/docs/adapters/memory-contract.md new file mode 100644 index 0000000..384b826 --- /dev/null +++ b/docs/adapters/memory-contract.md @@ -0,0 +1,194 @@ +# Memory persistence contract for adapters + +LayerLens adapters carry per-conversation, per-agent recall — episodic, +procedural, and semantic memory — alongside the trace events they emit. +This module ports the ad-hoc memory plumbing that the four mature +framework adapters (LangChain, AutoGen, CrewAI, Semantic Kernel) carry +on the `ateam` monorepo into a shared, replay-safe primitive that any +adapter on the `stratix-python` SDK can plug into. Without this +plumbing, lighter adapters behave as "goldfish agents" — every run +starts from a blank slate, which is the difference between a usable +production agent and a demo. + +This document defines the binding contract every adapter that integrates +the recorder must satisfy. It is enforced at runtime by +`MemoryRecorder.__init__` (fail-fast on missing tenant), by +`MemoryRecorder.restore` (cross-tenant rejection + content-hash +integrity check), and at CI time by +`tests/instrument/adapters/_base/test_memory.py`. + +Cross-pollination audit reference: +[`A:/tmp/adapter-cross-pollination-audit.md`](../../../tmp/adapter-cross-pollination-audit.md) +§2 item #1. 
+ +## The three buckets + +The memory model is the canonical agent-memory split that appears +across the literature (LangChain memory module; CrewAI procedural +memory; AutoGen episodic/semantic split): + +| Bucket | Lifetime | Bounded by | Eviction | +|----------------|-----------------------|------------------------|------------------------| +| **Episodic** | per-turn | `max_episodic` (200) | FIFO (oldest dropped) | +| **Procedural** | recurring patterns | `max_procedural` (16) | least-frequent + ties broken by oldest `last_seen_turn` | +| **Semantic** | long-lived facts | `max_semantic` (64) | least-recently-set | + +* **Episodic** — per-turn `(input, output, error?, tools?, extra?)` + records, ordered by `turn_index`. The detector for procedural + patterns reads this stream. +* **Procedural** — derived: each entry is + `{"pattern": [[prev_turn_tools], [cur_turn_tools]], "count": int, + "last_seen_turn": int}`. Detected automatically from the recent + episodic window every time `record_turn` is called. +* **Semantic** — caller-controlled key/value store + (`set_semantic(key, value)`). Both keys and values are stringified. + Callers wanting structured semantic data should JSON-encode their + value. + +All three buckets are **bounded** (CLAUDE.md "every cache must be +bounded"). The defaults are conservative; callers wanting a different +size construct the recorder with explicit `max_*` kwargs. + +## The contract + +1. **Every adapter owns exactly one `MemoryRecorder`.** Constructed in + `BaseAdapter.__init__` and exposed via `adapter.memory_recorder` + (read-only property). The recorder is bound to the same `org_id` as + the adapter — multi-tenancy is propagated. + +2. **Construction without a tenant raises.** `MemoryRecorder(org_id="")` + raises `ValueError`. There is no "default" tenant, no blank + fallback. `BaseAdapter.__init__` already fails fast on missing + `org_id` (see [multi-tenancy.md](multi-tenancy.md)) so the recorder + inherits that guarantee. + +3. 
**Every recorded turn is bounded.** A single oversized turn cannot + blow past the bucket caps: per-field strings longer than 8192 characters are + truncated, with a deterministic suffix (`<...truncated:orig_len=N>`) appended. + The truncation is defence-in-depth, not a substitute for the + adapter-level truncation policy (cross-poll #3). Detection of + recurring tool patterns runs in O(window) per turn. + +4. **Cross-tenant restore is prohibited.** `restore(snapshot)` raises + `ValueError` if `snapshot.org_id != recorder.org_id`. This mirrors + the `BaseAdapter.org_id` contract — a tenant-A snapshot cannot land + in a tenant-B recorder, even if both happen to share a process. + +5. **Snapshots are tamper-evident.** `restore(snapshot)` recomputes the + SHA-256 content hash and rejects the snapshot if the recorded hash + does not match. This guards against accidentally-mutated dicts in + transit and against forged snapshots reconstructed without the + `MemorySnapshot` factory. + +6. **Snapshots are replay-safe.** The round-trip + `snapshot() → restore() → snapshot()` produces a snapshot with the + identical `content_hash` (deterministic reconstruction). This is the + foundation of replay-safe memory: the replay engine restores the + recorder, then the adapter re-runs the agent and produces the same + next-turn snapshot. The `record_turn` method stamps a wall-clock + `timestamp_ns` into the new turn — replay engines suppress this + drift by capturing the original `timestamp_ns` from the source + trace and seeding the recorder's clock at restore time. + +7. **Snapshot serialisation is dict-shaped.** `snapshot.to_dict()` + returns a JSON-serialisable dict; `MemorySnapshot.from_dict(data)` + round-trips. Adapters embed the dict under + `ReplayableTrace.metadata["memory_snapshot"]` so the replay engine + can reconstruct via `MemoryRecorder.restore(MemorySnapshot.from_dict(...))`. + +8. **Recording is best-effort.** `BaseAdapter.record_memory_turn` + catches and logs all exceptions at DEBUG. 
A failure inside the + recorder MUST NOT propagate into the host framework's call stack — + tracing never breaks user code (CLAUDE.md). The trade-off is that a + recorder bug shows up as missing memory in the replay rather than a + crash in production. + +9. **Thread-safe.** `record_turn`, `set_semantic`, `clear`, and + `restore` are all guarded by an internal lock. Many concurrent + `record_turn` calls produce a snapshot whose `episodic` indices + form an unbroken `1..N` sequence. + +## Wiring at the lifecycle hook + +Every adapter wires `record_memory_turn(...)` into its **agent-output +boundary** — the point at which the framework reports a completed +agent step / chat turn / invocation. The exact hook varies by +framework: + +| Adapter | Hook | Episodic input | Tool list | +|---------------------|---------------------------------------------------------------------|---------------------------|--------------------------------------------------------------| +| `agno` | `Agent.run` / `arun` finally-block | `args[0] / kwargs["input"]`| `_collect_tool_names(agent, result)` from `result.messages` | +| `ms_agent_framework`| `Chat.invoke` / `invoke_stream` finally-block | `kwargs["input"]/["message"]`| `_collect_tool_names_from_messages(seen)` from streamed items| +| `openai_agents` | `_on_agent_span_end` (TraceProcessor) + `on_run_end` (Runner wrap) | cached at `_on_agent_span_start` per `span_id` | rolled up from `_on_function_span_end` per `parent_id` | +| `llama_index` | `_on_agent_step_end` | cached at `_on_agent_step_start` per thread id| rolled up from `_on_tool_call` per thread id | +| `google_adk` | `after_agent_callback` + `on_agent_end` | cached at `before_agent_callback` per thread id| rolled up from `after_tool_callback` per thread id | +| `bedrock_agents` | `_after_invoke_agent` (boto3 hook) | cached at `_before_invoke_agent` per thread id| rolled up from `_process_trace` action-group / KB step names | + +Each adapter also embeds its memory snapshot in 
`serialize_for_replay` +output via `ReplayableTrace.metadata["memory_snapshot"] = +self.memory_snapshot_dict()` — so a downstream replay engine can +reconstruct the full episodic + procedural + semantic state before +re-execution. + +## Honest scope disclosure (target adapter coverage) + +The cross-pollination audit §2 item #1 enumerates **seven** target +adapters: `agno`, `ms_agent_framework`, `openai_agents`, `llama_index`, +`google_adk`, `bedrock_agents`, **`browser_use`**. + +Six are wired in this PR. The seventh — `browser_use` — does not exist +on this branch's base (`feat/instrument-multitenancy-org-id-propagation`); +it lives on the parallel `feat/instrument-frameworks-browser-use-full` +history. It will be wired when that adapter is ported to this base or +when the histories merge. This follows the same honest-disclosure +pattern as PR #120 (state filters, which omitted `ms_agent_framework` +for the same reason — adapter not on its base). + +For `browser_use`, the eventual wiring (per the cross-pollination +audit) will be: + +* **Episodic** — page navigation events (`url`, `action`, `selector`) + per turn. +* **Procedural** — recurring `(prev_action, current_action)` patterns + (e.g. `"click[search]"` → `"type[query]"` → `"click[submit]"`). +* **Semantic** — long-lived page-content cache keyed by URL or DOM + hash, so a re-visit can short-circuit page reload during replay. + +## Audit hooks + +* **Construction failures** — `MemoryRecorder.__init__` raises with a + message naming the missing field (`"non-empty org_id"`, + `"bounded buffer sizes"`). +* **Cross-tenant restore** — raises with the explicit + `"Cross-tenant restore is prohibited (CLAUDE.md multi-tenancy)"` + message. +* **Tampered snapshots** — raises with + `"snapshot content_hash mismatch"` and includes the recorded vs + recomputed hashes for triage. 
+* **Best-effort recording failures** — logged at DEBUG via + `BaseAdapter.record_memory_turn` with `exc_info=True` so the failing + call site is preserved without escalating. + +## Replay engine integration + +A replay flow looks like: + +```python +# Original run captures both events and memory. +adapter = AgnoAdapter(stratix=client, org_id="tenant-A") +# ... agent runs, on_run_end fires record_memory_turn() ... +trace = adapter.serialize_for_replay() +trace.metadata["memory_snapshot"] # serialised MemorySnapshot dict. + +# Replay reconstructs the recorder before re-execution. +replay_adapter = AgnoAdapter(stratix=client, org_id="tenant-A") +snapshot = MemorySnapshot.from_dict(trace.metadata["memory_snapshot"]) +replay_adapter.memory_recorder.restore(snapshot) +# Re-run the agent — it sees the original recall state. +``` + +The next-turn snapshot taken from `replay_adapter` will match the +original (modulo the wall-clock `timestamp_ns` of the new turn — see +contract item 6). This is what makes memory persistence "replay-safe": +the replay engine can drive an adapter through the same agent state +the original run reached. 
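The tamper-evidence and round-trip guarantees described in this contract rest on hashing a canonical JSON encoding of the snapshot content. The following is a minimal, self-contained sketch of that idea. It assumes the same `json.dumps` arguments as `_canonical_json` in `memory.py`; the helper names `canonical_json`, `content_hash` and `verify`, and the sample snapshot values, are illustrative only and are not part of the SDK.

```python
import hashlib
import json
from typing import Any, Dict

# Fields that feed the hash, per the MemorySnapshot docstring.
HASHED_FIELDS = ("turn_index", "episodic", "procedural", "semantic", "org_id")


def canonical_json(value: Any) -> str:
    # Sorted keys and fixed separators remove ordering and whitespace variance.
    return json.dumps(value, sort_keys=True, separators=(",", ":"), ensure_ascii=False)


def content_hash(payload: Dict[str, Any]) -> str:
    # SHA-256 hex digest of the UTF-8 encoded canonical JSON.
    return hashlib.sha256(canonical_json(payload).encode("utf-8")).hexdigest()


def verify(snapshot: Dict[str, Any]) -> bool:
    # Recompute the hash from the hashed fields and compare with the stored one.
    payload = {name: snapshot[name] for name in HASHED_FIELDS}
    return content_hash(payload) == snapshot["content_hash"]


# Hypothetical snapshot content (values invented for illustration).
payload = {
    "turn_index": 1,
    "episodic": [
        {
            "turn_index": 1,
            "timestamp_ns": 0,
            "agent_name": "researcher",
            "input": "hi",
            "output": "hello",
        }
    ],
    "procedural": [],
    "semantic": {"user.locale": "en-GB"},
    "org_id": "tenant-A",
}
snapshot = dict(payload, content_hash=content_hash(payload))

# Dict key order does not change the hash because keys are sorted on encode.
reordered = dict(reversed(list(payload.items())))
same_hash = content_hash(reordered) == snapshot["content_hash"]

# Mutating any hashed field after the fact makes verification fail.
tampered = dict(snapshot, semantic={"user.locale": "fr-FR"})
intact_ok = verify(snapshot)
tampered_ok = verify(tampered)
```

In this sketch `verify` plays the role that the integrity check in `MemoryRecorder.restore` is described as playing: an unchanged snapshot passes, and a snapshot whose content was edited without recomputing `content_hash` is rejected.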
diff --git a/src/layerlens/instrument/adapters/_base/__init__.py b/src/layerlens/instrument/adapters/_base/__init__.py index daacfcc..a254abe 100644 --- a/src/layerlens/instrument/adapters/_base/__init__.py +++ b/src/layerlens/instrument/adapters/_base/__init__.py @@ -13,6 +13,13 @@ TraceStoreSink, IngestionPipelineSink, ) +from layerlens.instrument.adapters._base.memory import ( + DEFAULT_MAX_EPISODIC, + DEFAULT_MAX_SEMANTIC, + DEFAULT_MAX_PROCEDURAL, + MemoryRecorder, + MemorySnapshot, +) from layerlens.instrument.adapters._base.adapter import ( ORG_ID_FIELD, AdapterInfo, @@ -41,8 +48,13 @@ "AdapterStatus", "BaseAdapter", "CaptureConfig", + "DEFAULT_MAX_EPISODIC", + "DEFAULT_MAX_PROCEDURAL", + "DEFAULT_MAX_SEMANTIC", "EventSink", "IngestionPipelineSink", + "MemoryRecorder", + "MemorySnapshot", "ORG_ID_FIELD", "PydanticCompat", "ReplayableTrace", diff --git a/src/layerlens/instrument/adapters/_base/adapter.py b/src/layerlens/instrument/adapters/_base/adapter.py index b20ccb9..803b546 100644 --- a/src/layerlens/instrument/adapters/_base/adapter.py +++ b/src/layerlens/instrument/adapters/_base/adapter.py @@ -28,6 +28,7 @@ from layerlens.instrument.adapters._base.sinks import EventSink from layerlens._compat.pydantic import Field, BaseModel, model_dump +from layerlens.instrument.adapters._base.memory import MemoryRecorder, MemorySnapshot from layerlens.instrument.adapters._base.capture import ( ALWAYS_ENABLED_EVENT_TYPES, CaptureConfig, @@ -292,6 +293,18 @@ def __init__( # the public API and may change in v2. self._event_sinks: List["EventSink"] = list(event_sinks) if event_sinks else [] + # Cross-poll #1: per-adapter memory recorder. Bound to the same + # tenant as the adapter — :class:`MemoryRecorder` enforces the + # match on every :meth:`MemoryRecorder.restore`. Adapters call + # :meth:`record_memory_turn` after each ``agent.output``-style + # event so cross-conversation recall (episodic / procedural / + # semantic) lives alongside the trace events. 
The recorder is + # included in :meth:`serialize_for_replay` output via + # :meth:`memory_snapshot_dict` so a replay engine can deterministically + # reconstruct the agent's memory state before re-execution. See + # ``docs/adapters/memory-contract.md`` for the contract. + self._memory_recorder: MemoryRecorder = MemoryRecorder(org_id=self._org_id) + # --- Sink management (public API) --- def add_sink(self, sink: "EventSink") -> None: @@ -397,6 +410,88 @@ def info(self) -> AdapterInfo: base_info = base_info.copy(update={"requires_pydantic": self.requires_pydantic}) return base_info + # --- Memory persistence (cross-poll #1) ------------------------ + + @property + def memory_recorder(self) -> MemoryRecorder: + """The adapter's bound :class:`MemoryRecorder`. + + Constructed in :meth:`__init__` and tenant-scoped to the same + ``org_id`` as the adapter. Adapters wire :meth:`record_memory_turn` + into their per-turn lifecycle hooks (typically alongside the + ``agent.output`` emission). + """ + return self._memory_recorder + + def record_memory_turn( + self, + *, + agent_name: Optional[str] = None, + input_data: Any = None, + output_data: Any = None, + error: Optional[str] = None, + tools: Optional[List[str]] = None, + extra: Optional[Dict[str, Any]] = None, + ) -> None: + """Record a completed turn in the per-adapter memory recorder. + + Adapters call this from their ``on_run_end`` / ``after_agent_callback`` + / equivalent per-turn lifecycle hook *after* emitting the + ``agent.output`` event. The recorder is thread-safe and bounded + (see :class:`MemoryRecorder` defaults), so repeated calls under + concurrent execution are safe and never grow without bound. + + Failures are swallowed and logged at DEBUG: memory persistence is + a best-effort cross-cutting concern and must not break the host + framework's run path. CLAUDE.md "tracing never breaks user code". + + Args: + agent_name: The agent that produced this turn. + input_data: Input the agent received. 
+ output_data: Output the agent produced. + error: Optional error string if the turn failed. + tools: Optional list of tool names invoked during the turn. + extra: Optional adapter-specific metadata. Keys are sorted + inside the recorder for hash determinism. + """ + try: + self._memory_recorder.record_turn( + agent_name=agent_name, + input_data=input_data, + output_data=output_data, + error=error, + tools=tools, + extra=extra, + ) + except Exception: + # Memory persistence is best-effort. A failure here must + # never propagate into the host framework's call stack. + logger.debug( + "MemoryRecorder.record_turn failed for adapter %s", + self.FRAMEWORK, + exc_info=True, + ) + + def memory_snapshot(self) -> MemorySnapshot: + """Return the current immutable :class:`MemorySnapshot`. + + Convenience wrapper for :meth:`MemoryRecorder.snapshot` so + callers do not need to reach into the private recorder. + """ + return self._memory_recorder.snapshot() + + def memory_snapshot_dict(self) -> Dict[str, Any]: + """Return the memory snapshot as a JSON-serialisable dict. + + Adapters embed this in :meth:`serialize_for_replay`'s + :class:`ReplayableTrace` ``metadata`` under the + ``"memory_snapshot"`` key so the replay engine can reconstruct + the recorder's state via :meth:`MemoryRecorder.restore` before + re-executing the agent. The shape is the public, content-addressable + contract documented in ``docs/adapters/memory-contract.md``. + """ + return self._memory_recorder.snapshot().to_dict() + @abstractmethod def serialize_for_replay(self) -> ReplayableTrace: """Serialize the current trace data for replay.""" diff --git a/src/layerlens/instrument/adapters/_base/memory.py b/src/layerlens/instrument/adapters/_base/memory.py new file mode 100644 index 0000000..8c0642e --- /dev/null +++ b/src/layerlens/instrument/adapters/_base/memory.py @@ -0,0 +1,683 @@ +"""Shared cross-adapter memory persistence module. 
+ +This module ports the **episodic / procedural / semantic** memory +pattern carried ad-hoc by the four mature framework adapters +(LangChain, AutoGen, CrewAI, Semantic Kernel) into a shared, +replay-safe primitive that any framework adapter can plug in to +deliver cross-conversation recall. + +Background — Cross-pollination audit §2.1 +========================================= + +The audit at ``A:/tmp/adapter-cross-pollination-audit.md`` identified +memory persistence as the **highest-value cross-cutting lift** from +the mature adapters to the seven lighter adapters that lack it +(``agno``, ``ms_agent_framework``, ``openai_agents``, ``llama_index``, +``google_adk``, ``bedrock_agents``, ``browser_use``). Without this +plumbing those adapters behave as "goldfish agents" — every run starts +from a blank slate, which is the difference between a usable +production agent and a demo. + +The mature adapters all delegate to an external ``AgentMemoryService`` +(``stratix.memory.models.MemoryEntry``). That contract works in the +``ateam`` monorepo where the service is available; for the SDK in +``stratix-python`` the service is **not** part of the runtime, so the +shared module here owns the in-process snapshot lifecycle and exposes +a ``MemorySnapshot`` shape that any external service can consume from +:meth:`BaseAdapter.serialize_for_replay` output. + +Design contract +=============== + +Three memory buckets, modelled after the canonical agent-memory +literature (LangChain memory module; CrewAI procedural memory; +AutoGen episodic/semantic split): + +* **Episodic** — per-turn input/output pairs, ordered by ``turn_index``. + Bounded ring (default 200 entries) — drops oldest on overflow. +* **Procedural** — learned recurring patterns derived from the + episodic stream (e.g. "tool X is called immediately after tool Y"). + Bounded by a cap on the number of distinct patterns (default 16). 
+* **Semantic** — long-lived key/value facts that survive across many + conversations (e.g. user preferences, conversation summaries). + Bounded (default 64 entries) — least-recently-set eviction. + +Each :class:`MemorySnapshot` is **content-hash addressable** via a +SHA-256 of the canonical-JSON serialization of all three buckets plus +``turn_index`` and ``org_id``. Two snapshots with identical content +produce identical hashes — supports deduplication at the storage +layer. + +Replay safety contract +====================== + +A :class:`MemorySnapshot` is **deterministically restorable**: passing +the same snapshot to :meth:`MemoryRecorder.restore` and emitting the +same input sequence yields the same final snapshot. This means an +adapter that includes its current snapshot in +:meth:`BaseAdapter.serialize_for_replay` output can replay an agent +run with byte-exact memory state. + +Multi-tenancy +============= + +Every :class:`MemorySnapshot` and every :class:`MemoryRecorder` is +scoped to exactly one ``org_id``. A recorder constructed for tenant +A cannot ingest a snapshot from tenant B — :meth:`restore` raises +``ValueError`` on tenant mismatch. This mirrors the +:class:`BaseAdapter` org_id contract documented in +``docs/adapters/multi-tenancy.md``. +""" + +from __future__ import annotations + +import copy +import json +import time +import hashlib +import threading +from typing import Any, Dict, List, Tuple, Mapping, Optional +from dataclasses import dataclass + +# --------------------------------------------------------------------------- +# Defaults +# --------------------------------------------------------------------------- + +# Bounded buffers — CLAUDE.md "every cache must be bounded". These caps +# are conservative defaults; callers wanting a different size construct +# the recorder with explicit ``max_*`` kwargs. 
+DEFAULT_MAX_EPISODIC: int = 200 +DEFAULT_MAX_PROCEDURAL: int = 16 +DEFAULT_MAX_SEMANTIC: int = 64 + +# Per-turn truncation — the recorder is *not* the place to enforce +# field-size policy (that's the truncation module from cross-poll #3). +# But to prevent a single oversized turn from blowing past memory caps +# we apply a hard char-cap on individual values. This is a +# defence-in-depth limit, not a policy substitute. +_PER_FIELD_HARD_CAP: int = 8192 + +# Procedural pattern detection looks at ``_PROCEDURAL_WINDOW`` recent +# turns to find recurring (tool_name, next_action) pairs. +_PROCEDURAL_WINDOW: int = 16 + + +# --------------------------------------------------------------------------- +# Dataclass: MemorySnapshot +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class MemorySnapshot: + """Immutable, content-addressable memory snapshot. + + Snapshots are produced by :meth:`MemoryRecorder.snapshot` and + consumed by :meth:`MemoryRecorder.restore`. They are designed to: + + * Be **trivially serialisable** to JSON (every field is a primitive + or a dict/list of primitives). + * Be **content-addressable** via :attr:`content_hash` for + deduplication and integrity checks. + * Be **replay-safe**: restoring an adapter's recorder from a + snapshot and feeding the same input sequence reproduces the same + next snapshot. + * Be **multi-tenant-scoped** via :attr:`org_id` — snapshots from + tenant A cannot be restored into a tenant-B recorder. + + Attributes: + turn_index: Monotonic counter of completed turns at the moment + of the snapshot. Starts at 0 for an empty recorder. + episodic: Ordered list of recent turn dicts. Each turn carries + ``turn_index``, ``timestamp_ns``, ``agent_name``, ``input``, + ``output``, and optional ``error``, ``tools``, and ``extra``. + procedural: Ordered list of detected patterns. Each pattern is + a dict of the form + ``{"pattern": [...], "count": int, "last_seen_turn": int}``. 
+ semantic: Long-lived key/value store. Values are strings (or + stringified) — callers wanting structured semantic memory + should JSON-encode their value before storing. + content_hash: SHA-256 hex digest of the canonical JSON + representation of ``(turn_index, episodic, procedural, + semantic, org_id)``. Identical content → identical hash. + org_id: Tenant binding. Mirrors the adapter's bound ``org_id`` + and prevents cross-tenant restore. + """ + + turn_index: int + episodic: List[Dict[str, Any]] + procedural: List[Dict[str, Any]] + semantic: Dict[str, str] + content_hash: str + org_id: str + + def to_dict(self) -> Dict[str, Any]: + """Return a JSON-serialisable dict view of the snapshot. + + Used by :meth:`BaseAdapter.serialize_for_replay` when including + the memory state in a :class:`ReplayableTrace`. The shape is + stable: snapshot reconstruction reads the same keys back. + """ + return { + "turn_index": self.turn_index, + "episodic": copy.deepcopy(self.episodic), + "procedural": copy.deepcopy(self.procedural), + "semantic": dict(self.semantic), + "content_hash": self.content_hash, + "org_id": self.org_id, + } + + @classmethod + def from_dict(cls, data: Mapping[str, Any]) -> "MemorySnapshot": + """Reconstruct a snapshot from a previously-serialised dict. + + Args: + data: Mapping produced by :meth:`to_dict`. + + Returns: + A new :class:`MemorySnapshot` with deep-copied collections + so callers can mutate ``data`` afterwards without affecting + the snapshot. + + Raises: + ValueError: If ``data`` is missing required fields. 
+ """ + for required in ("turn_index", "episodic", "procedural", "semantic", "content_hash", "org_id"): + if required not in data: + raise ValueError(f"MemorySnapshot.from_dict missing required field: {required}") + return cls( + turn_index=int(data["turn_index"]), + episodic=copy.deepcopy(list(data["episodic"])), + procedural=copy.deepcopy(list(data["procedural"])), + semantic=dict(data["semantic"]), + content_hash=str(data["content_hash"]), + org_id=str(data["org_id"]), + ) + + +# --------------------------------------------------------------------------- +# Empty-snapshot factory +# --------------------------------------------------------------------------- + + +def _empty_snapshot(org_id: str) -> MemorySnapshot: + """Build an empty snapshot bound to ``org_id``. + + Used for the initial state of a fresh :class:`MemoryRecorder` and + by tests verifying the empty-state hash invariant. + """ + return _build_snapshot( + turn_index=0, + episodic=[], + procedural=[], + semantic={}, + org_id=org_id, + ) + + +def _canonical_json(value: Any) -> str: + """Return a deterministic JSON encoding suitable for hashing. + + ``sort_keys=True`` + ``separators`` removes whitespace variance. + ``default=str`` is *not* used — non-JSON-safe inputs are a caller + bug and should raise so the hash never silently changes shape. + """ + return json.dumps(value, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + + +def _compute_content_hash( + *, + turn_index: int, + episodic: List[Dict[str, Any]], + procedural: List[Dict[str, Any]], + semantic: Dict[str, str], + org_id: str, +) -> str: + """Compute the SHA-256 content hash for a snapshot. + + The exact bucket order and key set is part of the public contract: + changing it breaks dedup against historical snapshots. If the + content shape ever needs to grow, do it via a versioned wrapper + around the raw hash, not by mutating this function. 
+ """ + payload = { + "turn_index": turn_index, + "episodic": episodic, + "procedural": procedural, + "semantic": semantic, + "org_id": org_id, + } + encoded = _canonical_json(payload).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def _build_snapshot( + *, + turn_index: int, + episodic: List[Dict[str, Any]], + procedural: List[Dict[str, Any]], + semantic: Dict[str, str], + org_id: str, +) -> MemorySnapshot: + """Construct a :class:`MemorySnapshot` with a freshly-computed hash.""" + content_hash = _compute_content_hash( + turn_index=turn_index, + episodic=episodic, + procedural=procedural, + semantic=semantic, + org_id=org_id, + ) + return MemorySnapshot( + turn_index=turn_index, + episodic=copy.deepcopy(episodic), + procedural=copy.deepcopy(procedural), + semantic=dict(semantic), + content_hash=content_hash, + org_id=org_id, + ) + + +# --------------------------------------------------------------------------- +# Helper: cap a value at the per-field hard cap +# --------------------------------------------------------------------------- + + +def _cap_value(value: Any) -> Any: + """Apply the defence-in-depth per-field char cap. + + Returns the value unchanged for non-string types and short + strings; truncates long strings with a deterministic suffix that + makes the truncation visible in downstream tooling. + """ + if isinstance(value, str) and len(value) > _PER_FIELD_HARD_CAP: + # The suffix records the original length so reviewers can see + # how much was elided; the deterministic shape (no timestamps, + # no random IDs) preserves replay determinism. + return value[:_PER_FIELD_HARD_CAP] + f"<...truncated:orig_len={len(value)}>" + return value + + +def _normalise_turn( + *, + turn_index: int, + timestamp_ns: int, + agent_name: Optional[str], + input_data: Any, + output_data: Any, + error: Optional[str], + tools: Optional[List[str]], + extra: Optional[Mapping[str, Any]], +) -> Dict[str, Any]: + """Build the canonical episodic-turn dict. 
+ + Every recorded turn passes through this function so the schema is + enforced at one place. Non-JSON-safe inputs are coerced via + ``str()`` — adapters that want richer fidelity should serialise + upstream. + """ + turn: Dict[str, Any] = { + "turn_index": turn_index, + "timestamp_ns": timestamp_ns, + "agent_name": agent_name or "", + "input": _cap_value(input_data if isinstance(input_data, (str, int, float, bool, type(None))) else str(input_data)), + "output": _cap_value(output_data if isinstance(output_data, (str, int, float, bool, type(None))) else str(output_data)), + } + if error is not None: + turn["error"] = _cap_value(error) + if tools: + # Cap the tool list itself (not the strings inside) at a sane + # ceiling to prevent runaway tool-name accumulation. + turn["tools"] = [str(t) for t in tools[:32]] + if extra: + # ``extra`` is opt-in metadata. Keys are sorted to keep the + # hash deterministic even if callers pass dict literals. + normalised_extra: Dict[str, Any] = {} + for k in sorted(str(k) for k in extra.keys()): + normalised_extra[k] = _cap_value(extra[k] if isinstance(extra[k], (str, int, float, bool, type(None))) else str(extra[k])) + turn["extra"] = normalised_extra + return turn + + +# --------------------------------------------------------------------------- +# MemoryRecorder +# --------------------------------------------------------------------------- + + +class MemoryRecorder: + """Thread-safe accumulator wired into adapter lifecycle hooks. + + A :class:`MemoryRecorder` lives for the lifetime of an adapter + instance (one per adapter). The adapter calls + :meth:`record_turn` after every per-turn callback (typically right + after emitting ``agent.output``). At any point — most commonly + inside :meth:`BaseAdapter.serialize_for_replay` — the adapter + calls :meth:`snapshot` to obtain an immutable, hashable + :class:`MemorySnapshot` and embeds it in the replay trace. 
+ + The recorder enforces the multi-tenant binding: it is constructed + with the adapter's ``org_id`` and refuses to restore from a + snapshot whose ``org_id`` does not match. + + Args: + org_id: Tenant binding. Must be a non-empty string. Mirrors + :attr:`BaseAdapter.org_id`. + max_episodic: Maximum number of episodic turns retained. Older + turns are dropped FIFO when the cap is reached. + max_procedural: Maximum number of distinct procedural patterns + retained. + max_semantic: Maximum number of semantic key/value entries + retained. Eviction is least-recently-set. + """ + + def __init__( + self, + *, + org_id: str, + max_episodic: int = DEFAULT_MAX_EPISODIC, + max_procedural: int = DEFAULT_MAX_PROCEDURAL, + max_semantic: int = DEFAULT_MAX_SEMANTIC, + ) -> None: + if not isinstance(org_id, str) or not org_id.strip(): + raise ValueError( + "MemoryRecorder requires a non-empty org_id for multi-tenant " + "scoping. Pass the adapter's org_id (CLAUDE.md multi-tenancy)." + ) + if max_episodic < 1 or max_procedural < 1 or max_semantic < 1: + raise ValueError("MemoryRecorder bounded buffer sizes must be >= 1") + + self._org_id: str = org_id + self._max_episodic: int = max_episodic + self._max_procedural: int = max_procedural + self._max_semantic: int = max_semantic + + self._lock = threading.Lock() + self._turn_index: int = 0 + self._episodic: List[Dict[str, Any]] = [] + # Procedural store keyed by canonical pattern string so we + # detect repetition in O(1). + self._procedural: Dict[str, Dict[str, Any]] = {} + # Semantic store keyed by user-supplied key. The dict's own + # insertion order doubles as the LRU order (3.7+ dicts preserve + # insertion order, and we re-insert on update so eviction is + # least-recently-*set*, matching the documented contract). 
+ self._semantic: Dict[str, str] = {} + + # --- Properties (read-only) ------------------------------------ + + @property + def org_id(self) -> str: + """The tenant binding fixed at construction.""" + return self._org_id + + @property + def turn_index(self) -> int: + """Monotonic count of completed turns. Thread-safe read. + + Returned by value: the underlying counter is locked during + :meth:`record_turn` and :meth:`restore`, but a read here is a + plain int load. + """ + return self._turn_index + + # --- Recording ------------------------------------------------- + + def record_turn( + self, + *, + agent_name: Optional[str] = None, + input_data: Any = None, + output_data: Any = None, + error: Optional[str] = None, + tools: Optional[List[str]] = None, + extra: Optional[Mapping[str, Any]] = None, + ) -> int: + """Record one completed agent turn. + + Adapters wire this into their lifecycle right after the + ``agent.output`` event emission. The turn data is enrolled in + the episodic buffer; procedural patterns are detected from the + recent window; the monotonic counter is incremented and + returned. + + Args: + agent_name: The agent that produced this turn (e.g. + ``"researcher"``). + input_data: The input the agent received. + output_data: The output the agent produced. + error: Optional error message if the turn failed. + tools: Optional list of tool names invoked during the turn. + extra: Optional additional metadata. Keys are sorted to + keep snapshot hashes deterministic. + + Returns: + The new ``turn_index`` after recording. + """ + with self._lock: + self._turn_index += 1 + turn = _normalise_turn( + turn_index=self._turn_index, + timestamp_ns=time.time_ns(), + agent_name=agent_name, + input_data=input_data, + output_data=output_data, + error=error, + tools=list(tools) if tools else None, + extra=extra, + ) + self._episodic.append(turn) + # Bounded: drop oldest turns FIFO. 
We use in-place slice + # deletion rather than ``deque`` because episodic is also + # serialised as a plain list, so no conversion is needed. + if len(self._episodic) > self._max_episodic: + drop = len(self._episodic) - self._max_episodic + del self._episodic[:drop] + self._detect_procedural_patterns() + return self._turn_index + + def set_semantic(self, key: str, value: str) -> None: + """Set or overwrite a long-lived semantic memory entry. + + Keys are sorted at snapshot time. Values are coerced to + strings; structured semantic data should be JSON-encoded by + the caller. + + Args: + key: Semantic memory key. Must be non-empty. + value: Value to store. Coerced via ``str()`` if not + already a string. Hard-capped at + :data:`_PER_FIELD_HARD_CAP` characters. + + Raises: + ValueError: If ``key`` is not a string or is blank. + """ + if not isinstance(key, str) or not key.strip(): + raise ValueError("MemoryRecorder.set_semantic requires a non-empty key") + capped = _cap_value(value if isinstance(value, str) else str(value)) + with self._lock: + # Re-insert so the key moves to the newest position (3.7+ dicts). + if key in self._semantic: + del self._semantic[key] + self._semantic[key] = capped + # Bounded: evict least-recently-set entries. + if len(self._semantic) > self._max_semantic: + # ``next(iter(d))`` returns the first key in insertion + # order, i.e. the least-recently-set entry. + drop = len(self._semantic) - self._max_semantic + for _ in range(drop): + oldest = next(iter(self._semantic)) + del self._semantic[oldest] + + def clear(self) -> None: + """Reset all buckets to empty; ``turn_index`` resets to 0. + + Useful for adapters that want a fresh memory state when a new + conversation/session begins. Multi-tenant binding is + preserved.
+ """ + with self._lock: + self._turn_index = 0 + self._episodic.clear() + self._procedural.clear() + self._semantic.clear() + + # --- Snapshots -------------------------------------------------- + + def snapshot(self) -> MemorySnapshot: + """Return an immutable, content-addressable snapshot. + + Adapters should call this from + :meth:`BaseAdapter.serialize_for_replay` to embed the memory + state in the replay trace. The returned snapshot is a deep + copy — mutating the recorder afterwards never affects a + previously-returned snapshot (immutability invariant). + """ + with self._lock: + procedural_list = self._procedural_as_sorted_list() + return _build_snapshot( + turn_index=self._turn_index, + episodic=self._episodic, + procedural=procedural_list, + semantic=self._semantic, + org_id=self._org_id, + ) + + def restore(self, snapshot: MemorySnapshot) -> None: + """Replace the recorder's state with a previously-taken snapshot. + + The recorder is rebuilt to byte-exact equivalence: a fresh + :meth:`snapshot` immediately after a :meth:`restore` returns a + snapshot with the same :attr:`MemorySnapshot.content_hash` + (deterministic round-trip). This is the foundation of + replay-safe memory: the replay engine restores the recorder, + then the adapter re-runs the agent and produces the same + next-turn snapshot. + + Args: + snapshot: The :class:`MemorySnapshot` to restore from. + + Raises: + ValueError: If ``snapshot.org_id`` does not match the + recorder's tenant binding (cross-tenant restore is + prohibited). + ValueError: If the snapshot's recorded + :attr:`MemorySnapshot.content_hash` does not match the + hash recomputed from its content (integrity check). + """ + if snapshot.org_id != self._org_id: + raise ValueError( + f"MemoryRecorder.restore: snapshot org_id={snapshot.org_id!r} " + f"does not match recorder org_id={self._org_id!r}. " + "Cross-tenant restore is prohibited (CLAUDE.md multi-tenancy)." + ) + # Verify the snapshot's stored hash matches its content. 
Guards + # against accidentally-mutated dicts in transit. + recomputed = _compute_content_hash( + turn_index=snapshot.turn_index, + episodic=snapshot.episodic, + procedural=snapshot.procedural, + semantic=snapshot.semantic, + org_id=snapshot.org_id, + ) + if recomputed != snapshot.content_hash: + raise ValueError( + "MemoryRecorder.restore: snapshot content_hash mismatch — " + "snapshot has been tampered with or is corrupted. " + f"Recorded={snapshot.content_hash} recomputed={recomputed}." + ) + with self._lock: + self._turn_index = snapshot.turn_index + self._episodic = copy.deepcopy(snapshot.episodic) + self._semantic = dict(snapshot.semantic) + # Procedural store is keyed internally; rebuild the dict + # form from the snapshot's list form. + self._procedural = {} + for entry in snapshot.procedural: + pattern_key = _canonical_json(entry["pattern"]) + self._procedural[pattern_key] = { + "pattern": list(entry["pattern"]), + "count": int(entry["count"]), + "last_seen_turn": int(entry["last_seen_turn"]), + } + + # --- Internal: procedural-pattern detection -------------------- + + def _detect_procedural_patterns(self) -> None: + """Scan recent turns for recurring tool sequences. + + Caller MUST hold ``self._lock``. The detector looks at the + last :data:`_PROCEDURAL_WINDOW` turns and records any + ``(prev_turn_tools, current_turn_tools)`` pair that recurs. + + The detection is deliberately simple — pairwise tool-list + sequences only — to keep the algorithm O(window) per turn. + Adapters wanting richer pattern detection can layer their own + analysis on top of the episodic stream returned by + :meth:`snapshot`. 
+ """ + if len(self._episodic) < 2: + return + window = self._episodic[-_PROCEDURAL_WINDOW:] + for i in range(1, len(window)): + prev_tools = window[i - 1].get("tools") or [] + cur_tools = window[i].get("tools") or [] + if not prev_tools and not cur_tools: + continue + pattern: List[List[str]] = [list(prev_tools), list(cur_tools)] + pattern_key = _canonical_json(pattern) + existing = self._procedural.get(pattern_key) + if existing is not None: + existing["count"] += 1 + existing["last_seen_turn"] = window[i]["turn_index"] + else: + self._procedural[pattern_key] = { + "pattern": pattern, + "count": 1, + "last_seen_turn": window[i]["turn_index"], + } + # Bounded: keep only top-N most-frequent patterns. Ties broken + # by ``last_seen_turn`` (more recent wins) — deterministic. + if len(self._procedural) > self._max_procedural: + ranked: List[Tuple[str, Dict[str, Any]]] = sorted( + self._procedural.items(), + key=lambda kv: (-kv[1]["count"], -kv[1]["last_seen_turn"]), + ) + self._procedural = dict(ranked[: self._max_procedural]) + + def _procedural_as_sorted_list(self) -> List[Dict[str, Any]]: + """Return procedural store as a deterministically-ordered list. + + Caller MUST hold ``self._lock``. Sort key is + ``(-count, -last_seen_turn, canonical_pattern_json)`` so the + snapshot's list order is independent of insertion order → + identical content yields identical hashes. + """ + # Build a typed (sort_key_tuple, payload_dict) list so mypy --strict + # can prove the sort lambda inputs are ints / strings (not the + # ``object`` widening that ``Dict[str, Any]`` produces). 
+ ranked: List[Tuple[Tuple[int, int, str], Dict[str, Any]]] = [] + for pattern_key, entry in self._procedural.items(): + count = int(entry["count"]) + last_seen = int(entry["last_seen_turn"]) + payload: Dict[str, Any] = { + "pattern": list(entry["pattern"]), + "count": count, + "last_seen_turn": last_seen, + } + # Negate count/last_seen so the natural ascending tuple sort + # produces "highest count first, then most recent first". + ranked.append(((-count, -last_seen, pattern_key), payload)) + ranked.sort(key=lambda item: item[0]) + return [payload for _, payload in ranked] + + +# --------------------------------------------------------------------------- +# Public re-exports +# --------------------------------------------------------------------------- + + +__all__ = [ + "DEFAULT_MAX_EPISODIC", + "DEFAULT_MAX_PROCEDURAL", + "DEFAULT_MAX_SEMANTIC", + "MemoryRecorder", + "MemorySnapshot", +] diff --git a/src/layerlens/instrument/adapters/frameworks/agno/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/agno/lifecycle.py index b940b62..1011bea 100644 --- a/src/layerlens/instrument/adapters/frameworks/agno/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/agno/lifecycle.py @@ -122,7 +122,13 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: - """Serialize the current trace data for replay.""" + """Serialize the current trace data for replay. + + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]`` so the replay engine can call + :meth:`MemoryRecorder.restore` before re-execution and reach + byte-identical memory state. 
+ """ return ReplayableTrace( adapter_name="AgnoAdapter", framework=self.FRAMEWORK, @@ -130,6 +136,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -173,7 +180,16 @@ async def traced_run(*args: Any, **kwargs: Any) -> Any: output = None if result is not None: output = getattr(result, "content", result) - adapter.on_run_end(agent_name=agent_name, output=output, error=error) + # Surface input + tool list to the memory recorder so the + # episodic and procedural buckets capture the full turn. + tool_names = adapter._collect_tool_names(agent, result) + adapter.on_run_end( + agent_name=agent_name, + output=output, + error=error, + input_data=input_data, + tools=tool_names, + ) adapter._extract_run_details(agent, result) return result @@ -199,7 +215,16 @@ def traced_run_sync(*args: Any, **kwargs: Any) -> Any: output = None if result is not None: output = getattr(result, "content", result) - adapter.on_run_end(agent_name=agent_name, output=output, error=error) + # Surface input + tool list to the memory recorder so the + # episodic and procedural buckets capture the full turn. 
+ tool_names = adapter._collect_tool_names(agent, result) + adapter.on_run_end( + agent_name=agent_name, + output=output, + error=error, + input_data=input_data, + tools=tool_names, + ) adapter._extract_run_details(agent, result) return result @@ -304,6 +329,8 @@ def on_run_end( agent_name: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: """Emit agent.output event when an agent run ends.""" if not self._connected: @@ -331,6 +358,18 @@ def on_run_end( "event_subtype": "run_complete" if not error else "run_failed", }, ) + # Cross-poll #1: persist this turn into the per-adapter + # memory recorder. Episodic = the input/output pair; + # procedural patterns are detected from recurring tool + # sequences in the recorder. Recording is best-effort and + # never breaks the host run path (BaseAdapter swallows). + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_run_end", exc_info=True) @@ -460,6 +499,36 @@ def _emit_agent_config(self, agent_name: str, agent: Any) -> None: metadata["team_members"] = [getattr(m, "name", str(m)) for m in members] self.emit_dict_event("environment.config", metadata) + def _collect_tool_names(self, agent: Any, result: Any) -> list[str]: # noqa: ARG002 + """Collect tool names invoked during the run for memory persistence. + + Used by the memory recorder's procedural-pattern detector — a + short, stable list of tool names is preferable to the full tool + argument blob. + + ``agent`` is accepted for symmetry with peer ``_collect_*``/ + ``_extract_run_details`` helpers and to leave the door open for + agent-side tool introspection (e.g. ``agent.tools`` discovery) + if the result object stops carrying tool calls in a future Agno + release. 
Currently unused but kept on the contract. + """ + names: list[str] = [] + try: + messages = getattr(result, "messages", None) or [] + for msg in messages: + tool_calls = getattr(msg, "tool_calls", None) or [] + for tc in tool_calls: + fn = getattr(tc, "function", None) + if isinstance(fn, dict): + name = fn.get("name") + else: + name = getattr(fn, "name", None) + if name: + names.append(str(name)) + except Exception: + logger.debug("Could not collect tool names", exc_info=True) + return names + def _safe_serialize(self, value: Any) -> Any: """Safely serialize a value for event payloads.""" try: diff --git a/src/layerlens/instrument/adapters/frameworks/bedrock_agents/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/bedrock_agents/lifecycle.py index aa7dbea..e5019b5 100644 --- a/src/layerlens/instrument/adapters/frameworks/bedrock_agents/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/bedrock_agents/lifecycle.py @@ -58,6 +58,12 @@ def __init__( self._seen_agents: set[str] = set() self._framework_version: str | None = None self._invoke_starts: dict[int, int] = {} + # Per-thread cache of input + tool list for memory persistence. + # boto3 hooks fire on the thread that invoked the SDK call, so + # thread-id keying is safe and matches _invoke_starts. + self._invoke_inputs: dict[int, Any] = {} + self._invoke_agent_ids: dict[int, str] = {} + self._invoke_tools: dict[int, list[str]] = {} def connect(self) -> None: try: @@ -87,6 +93,9 @@ def disconnect(self) -> None: logger.debug("Could not unregister boto3 event hooks", exc_info=True) self._originals.clear() self._seen_agents.clear() + self._invoke_inputs.clear() + self._invoke_agent_ids.clear() + self._invoke_tools.clear() self._connected = False self._status = AdapterStatus.DISCONNECTED @@ -116,6 +125,13 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: + """Serialize the current trace data for replay. 
+ + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]``. Bedrock session attributes are + consolidated into the snapshot at after_invoke so a replay + engine can reconstruct cross-session recall. + """ return ReplayableTrace( adapter_name="BedrockAgentsAdapter", framework=self.FRAMEWORK, @@ -123,6 +139,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -153,9 +170,13 @@ def _before_invoke_agent(self, **kwargs: Any) -> None: params = kwargs.get("params", {}) tid = threading.get_ident() start_ns = time.time_ns() + agent_id = params.get("agentId", "unknown") + input_text = params.get("inputText") with self._adapter_lock: self._invoke_starts[tid] = start_ns - agent_id = params.get("agentId", "unknown") + self._invoke_inputs[tid] = input_text + self._invoke_agent_ids[tid] = agent_id + self._invoke_tools[tid] = [] self._emit_agent_config(agent_id, params) self.emit_dict_event( "agent.input", @@ -163,7 +184,7 @@ def _before_invoke_agent(self, **kwargs: Any) -> None: "framework": "bedrock_agents", "agent_id": agent_id, "session_id": params.get("sessionId"), - "input": params.get("inputText"), + "input": input_text, "enable_trace": params.get("enableTrace", False), "timestamp_ns": start_ns, }, @@ -180,6 +201,9 @@ def _after_invoke_agent(self, **kwargs: Any) -> None: end_ns = time.time_ns() with self._adapter_lock: start_ns = self._invoke_starts.pop(tid, 0) + input_text = self._invoke_inputs.pop(tid, None) + agent_id = self._invoke_agent_ids.pop(tid, "") + tool_names = self._invoke_tools.pop(tid, []) duration_ns = end_ns - start_ns if start_ns else 0 output = self._extract_completion(parsed) self.emit_dict_event( @@ -191,13 +215,36 @@ def _after_invoke_agent(self, **kwargs: Any) -> None: "session_id": 
parsed.get("sessionId"), }, ) - # Extract trace steps if available - self._process_trace(parsed) + # Extract trace steps if available. _process_trace appends + # action-group and knowledge-base step names to the + # ``tool_names`` list in place; that list was already popped + # from _invoke_tools under the lock above, so the cache + # entry for this thread is not repopulated. + self._process_trace(parsed, tool_names=tool_names) + # Cross-poll #1: persist this turn into memory. + self.record_memory_turn( + agent_name=agent_id, + input_data=self._safe_serialize(input_text), + output_data=self._safe_serialize(output), + tools=tool_names or None, + ) except Exception: logger.warning("Error in _after_invoke_agent", exc_info=True) - def _process_trace(self, parsed: dict[str, Any]) -> None: - """Extract trace steps from Bedrock response and emit events.""" + def _process_trace( + self, + parsed: dict[str, Any], + tool_names: list[str] | None = None, + ) -> None: + """Extract trace steps from Bedrock response and emit events. + + Args: + parsed: The boto3-parsed InvokeAgent response. + tool_names: Mutable list the caller passed in to receive + the tool-name roll-up. When supplied, action-group names + and knowledge-base IDs are appended to it so the caller + can pass them to :meth:`record_memory_turn`.
+ """ trace = parsed.get("trace", {}) steps = trace.get("trace", {}).get("orchestrationTrace", {}).get("steps", []) if not steps and isinstance(trace, dict): @@ -207,8 +254,15 @@ def _process_trace(self, parsed: dict[str, Any]) -> None: step_type = step.get("type", "") if step_type == "ACTION_GROUP": self._emit_action_group(step) + if tool_names is not None: + name = step.get("actionGroupName") + if name: + tool_names.append(str(name)) elif step_type == "KNOWLEDGE_BASE": self._emit_knowledge_base(step) + if tool_names is not None: + name = step.get("knowledgeBaseId") or "knowledge_base" + tool_names.append(str(name)) elif step_type == "MODEL_INVOCATION": self._emit_model_invocation(step) elif step_type == "AGENT_COLLABORATOR": @@ -304,6 +358,8 @@ def on_invoke_end( agent_id: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: if not self._connected: return @@ -312,6 +368,15 @@ def on_invoke_end( end_ns = time.time_ns() with self._adapter_lock: start_ns = self._invoke_starts.pop(tid, 0) + if input_data is None: + input_data = self._invoke_inputs.pop(tid, None) + else: + self._invoke_inputs.pop(tid, None) + if tools is None: + tools = self._invoke_tools.pop(tid, []) or None + else: + self._invoke_tools.pop(tid, None) + self._invoke_agent_ids.pop(tid, None) duration_ns = end_ns - start_ns if start_ns else 0 payload: dict[str, Any] = { "framework": "bedrock_agents", @@ -322,6 +387,14 @@ def on_invoke_end( if error: payload["error"] = str(error) self.emit_dict_event("agent.output", payload) + # Cross-poll #1: persist this turn into memory. 
+ self.record_memory_turn( + agent_name=agent_id, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_invoke_end", exc_info=True) diff --git a/src/layerlens/instrument/adapters/frameworks/google_adk/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/google_adk/lifecycle.py index 5de8ea3..7f995f0 100644 --- a/src/layerlens/instrument/adapters/frameworks/google_adk/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/google_adk/lifecycle.py @@ -61,6 +61,11 @@ def __init__( self._model_call_starts: dict[int, int] = {} # thread_id -> start_ns self._tool_call_starts: dict[str, int] = {} self._agent_starts: dict[int, int] = {} # thread_id -> start_ns + # Per-thread state for memory persistence — mirrors the + # _agent_starts thread-id key. ADK callbacks fire on the same + # thread that drives the agent so this is a natural fit. + self._agent_inputs: dict[int, Any] = {} + self._agent_tools: dict[int, list[str]] = {} def connect(self) -> None: try: @@ -83,6 +88,8 @@ def disconnect(self) -> None: self._model_call_starts.clear() self._tool_call_starts.clear() self._agent_starts.clear() + self._agent_inputs.clear() + self._agent_tools.clear() self._connected = False self._status = AdapterStatus.DISCONNECTED @@ -112,6 +119,13 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: + """Serialize the current trace data for replay. + + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]``. ADK's session/state machinery + is consolidated into the snapshot at after_agent_callback so a + replay engine can reconstruct conversational state. 
+ """ return ReplayableTrace( adapter_name="GoogleADKAdapter", framework=self.FRAMEWORK, @@ -119,6 +133,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -146,14 +161,17 @@ def _before_agent_callback(self, callback_context: Any) -> Any: self._emit_agent_config(agent_name, callback_context) tid = threading.get_ident() start_ns = time.time_ns() + user_content = getattr(callback_context, "user_content", None) with self._adapter_lock: self._agent_starts[tid] = start_ns + self._agent_inputs[tid] = user_content + self._agent_tools[tid] = [] self.emit_dict_event( "agent.input", { "framework": "google_adk", "agent_name": agent_name, - "input": self._safe_serialize(getattr(callback_context, "user_content", None)), + "input": self._safe_serialize(user_content), "timestamp_ns": start_ns, }, ) @@ -170,16 +188,26 @@ def _after_agent_callback(self, callback_context: Any) -> Any: end_ns = time.time_ns() with self._adapter_lock: start_ns = self._agent_starts.pop(tid, 0) + user_content = self._agent_inputs.pop(tid, None) + tool_names = self._agent_tools.pop(tid, []) duration_ns = end_ns - start_ns if start_ns else 0 + agent_output = getattr(callback_context, "agent_output", None) self.emit_dict_event( "agent.output", { "framework": "google_adk", "agent_name": agent_name, - "output": self._safe_serialize(getattr(callback_context, "agent_output", None)), + "output": self._safe_serialize(agent_output), "duration_ns": duration_ns, }, ) + # Cross-poll #1: persist this turn into memory. 
+ self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(user_content), + output_data=self._safe_serialize(agent_output), + tools=tool_names or None, + ) except Exception: logger.warning("Error in after_agent_callback", exc_info=True) return None @@ -269,6 +297,14 @@ def _after_tool_callback( "latency_ms": latency_ms, }, ) + # Memory persistence: record this tool name under the + # in-flight agent turn so the procedural-pattern detector + # sees it. Same thread as the agent that fired the tool. + tid = threading.get_ident() + with self._adapter_lock: + bucket = self._agent_tools.get(tid) + if bucket is not None: + bucket.append(str(tool_name)) except Exception: logger.warning("Error in after_tool_callback", exc_info=True) return None @@ -300,6 +336,8 @@ def on_agent_end( agent_name: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: if not self._connected: return @@ -308,6 +346,14 @@ def on_agent_end( end_ns = time.time_ns() with self._adapter_lock: start_ns = self._agent_starts.pop(tid, 0) + if input_data is None: + input_data = self._agent_inputs.pop(tid, None) + else: + self._agent_inputs.pop(tid, None) + if tools is None: + tools = self._agent_tools.pop(tid, []) or None + else: + self._agent_tools.pop(tid, None) duration_ns = end_ns - start_ns if start_ns else 0 payload: dict[str, Any] = { "framework": "google_adk", @@ -318,6 +364,14 @@ def on_agent_end( if error: payload["error"] = str(error) self.emit_dict_event("agent.output", payload) + # Cross-poll #1: persist this turn into memory. 
+ self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_agent_end", exc_info=True) diff --git a/src/layerlens/instrument/adapters/frameworks/llama_index/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/llama_index/lifecycle.py index acbf1a9..4aded42 100644 --- a/src/layerlens/instrument/adapters/frameworks/llama_index/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/llama_index/lifecycle.py @@ -60,6 +60,11 @@ def __init__( self._framework_version: str | None = None self._event_handler: Any | None = None self._agent_starts: dict[int, int] = {} # thread_id -> start_ns + # Per-thread cache of agent input + tool list for memory persistence. + # LlamaIndex events arrive on the thread running the workflow, + # so thread-id keying matches the existing _agent_starts pattern. + self._agent_inputs: dict[int, Any] = {} + self._agent_tools: dict[int, list[str]] = {} def connect(self) -> None: try: @@ -93,6 +98,8 @@ def disconnect(self) -> None: self._event_handler = None self._originals.clear() self._seen_agents.clear() + self._agent_inputs.clear() + self._agent_tools.clear() self._connected = False self._status = AdapterStatus.DISCONNECTED @@ -122,6 +129,13 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: + """Serialize the current trace data for replay. + + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]``. LlamaIndex's ``ChatMemoryBuffer`` + is consolidated into the snapshot at workflow end so a replay + engine can re-hydrate cross-conversation context. 
+ """ return ReplayableTrace( adapter_name="LlamaIndexAdapter", framework=self.FRAMEWORK, @@ -129,6 +143,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -229,15 +244,24 @@ def _on_llm_end(self, event: Any) -> None: ) def _on_tool_call(self, event: Any) -> None: + tool_name = getattr(event, "tool_name", None) or getattr(event, "name", "unknown") self.emit_dict_event( "tool.call", { "framework": "llama_index", - "tool_name": getattr(event, "tool_name", None) or getattr(event, "name", "unknown"), + "tool_name": tool_name, "tool_input": self._safe_serialize(getattr(event, "tool_input", None)), "tool_output": self._safe_serialize(getattr(event, "tool_output", None)), }, ) + # Memory persistence: record this tool name under the in-flight + # agent step so the procedural-pattern detector sees it. Same + # thread as the agent step that triggered the tool call. + tid = threading.get_ident() + with self._adapter_lock: + bucket = self._agent_tools.get(tid) + if bucket is not None: + bucket.append(str(tool_name)) def _on_retrieval_start(self, event: Any) -> None: pass # Tracked on end @@ -260,8 +284,17 @@ def _on_agent_step_start(self, event: Any) -> None: self._emit_agent_config(agent_name, event) tid = threading.get_ident() start_ns = time.time_ns() + # Stash the step input + reset the per-step tool list for memory + # persistence at step_end. 
+ step_input = ( + getattr(event, "input", None) + or getattr(event, "step", None) + or getattr(event, "task", None) + ) with self._adapter_lock: self._agent_starts[tid] = start_ns + self._agent_inputs[tid] = step_input + self._agent_tools[tid] = [] self.emit_dict_event( "agent.input", { @@ -278,16 +311,26 @@ def _on_agent_step_end(self, event: Any) -> None: end_ns = time.time_ns() with self._adapter_lock: start_ns = self._agent_starts.pop(tid, 0) + step_input = self._agent_inputs.pop(tid, None) + tool_names = self._agent_tools.pop(tid, []) duration_ns = end_ns - start_ns if start_ns else 0 + response = getattr(event, "response", None) self.emit_dict_event( "agent.output", { "framework": "llama_index", "agent_name": agent_name, - "output": self._safe_serialize(getattr(event, "response", None)), + "output": self._safe_serialize(response), "duration_ns": duration_ns, }, ) + # Cross-poll #1: persist this step into memory. + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(step_input), + output_data=self._safe_serialize(response), + tools=tool_names or None, + ) # --- Lifecycle Hooks --- @@ -316,6 +359,8 @@ def on_agent_end( agent_name: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: if not self._connected: return @@ -324,6 +369,16 @@ def on_agent_end( end_ns = time.time_ns() with self._adapter_lock: start_ns = self._agent_starts.pop(tid, 0) + # Pull thread-cached state if the public hook was used + # without the explicit kwargs. 
+ if input_data is None: + input_data = self._agent_inputs.pop(tid, None) + else: + self._agent_inputs.pop(tid, None) + if tools is None: + tools = self._agent_tools.pop(tid, []) or None + else: + self._agent_tools.pop(tid, None) duration_ns = end_ns - start_ns if start_ns else 0 payload: dict[str, Any] = { "framework": "llama_index", @@ -334,6 +389,14 @@ def on_agent_end( if error: payload["error"] = str(error) self.emit_dict_event("agent.output", payload) + # Cross-poll #1: persist this turn into memory. + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_agent_end", exc_info=True) diff --git a/src/layerlens/instrument/adapters/frameworks/ms_agent_framework/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/ms_agent_framework/lifecycle.py index ea7c153..87eda9f 100644 --- a/src/layerlens/instrument/adapters/frameworks/ms_agent_framework/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/ms_agent_framework/lifecycle.py @@ -127,7 +127,12 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: - """Serialize the current trace data for replay.""" + """Serialize the current trace data for replay. + + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]``. Powers replay re-execution with + a deterministically-restored memory state. 
+ """ return ReplayableTrace( adapter_name="MSAgentAdapter", framework=self.FRAMEWORK, @@ -135,6 +140,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -186,7 +192,16 @@ async def traced_invoke(*args: Any, **kwargs: Any) -> Any: raise finally: output = adapter._safe_serialize(results[-1]) if results else None - adapter.on_run_end(agent_name=agent_name, output=output, error=error) + # Roll up the tool names seen across the streamed turn for + # the memory recorder's procedural-pattern detector. + tool_names = adapter._collect_tool_names_from_messages(results) + adapter.on_run_end( + agent_name=agent_name, + output=output, + error=error, + input_data=input_data, + tools=tool_names, + ) traced_invoke._layerlens_original = original_invoke # type: ignore[attr-defined] return traced_invoke @@ -199,19 +214,29 @@ async def traced_invoke_stream(*args: Any, **kwargs: Any) -> Any: chat_name = getattr(chat, "name", None) or "ms_agent_chat" agent = kwargs.get("agent") or (args[0] if args else None) agent_name = getattr(agent, "name", None) or chat_name if agent else chat_name - adapter.on_run_start(agent_name=agent_name, input_data=None) + input_data = kwargs.get("input") or kwargs.get("message") + adapter.on_run_start(agent_name=agent_name, input_data=input_data) error: Exception | None = None last_message = None + seen: list[Any] = [] try: async for message in original_invoke_stream(*args, **kwargs): last_message = message + seen.append(message) yield message except Exception as exc: error = exc raise finally: output = adapter._safe_serialize(last_message) if last_message else None - adapter.on_run_end(agent_name=agent_name, output=output, error=error) + tool_names = adapter._collect_tool_names_from_messages(seen) + adapter.on_run_end( + agent_name=agent_name, 
+ output=output, + error=error, + input_data=input_data, + tools=tool_names, + ) traced_invoke_stream._layerlens_original = original_invoke_stream # type: ignore[attr-defined] return traced_invoke_stream @@ -311,6 +336,8 @@ def on_run_end( agent_name: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: """Emit agent.output event when a chat invocation ends.""" if not self._connected: @@ -338,6 +365,18 @@ def on_run_end( "event_subtype": "run_complete" if not error else "run_failed", }, ) + # Cross-poll #1: persist the chat turn into memory. + # Episodic = the (input, final-output) pair; procedural + # patterns build up across chat turns within a recorder + # lifetime. Tools, when collected from the stream, feed + # the procedural-pattern detector. + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_run_end", exc_info=True) @@ -490,3 +529,23 @@ def _safe_serialize(self, value: Any) -> Any: return str(value) except Exception: return str(value) + + def _collect_tool_names_from_messages(self, messages: list[Any]) -> list[str]: + """Collect tool names (repeats included) invoked across a streamed turn. + + Used by the memory recorder's procedural-pattern detector — only + the names are needed, not the full arguments. 
+ """ + names: list[str] = [] + try: + for message in messages: + items = getattr(message, "items", None) or [] + for item in items: + item_type = type(item).__name__ + if "FunctionCall" in item_type or "ToolCall" in item_type: + name = getattr(item, "name", None) or getattr(item, "function_name", None) + if name: + names.append(str(name)) + except Exception: + logger.debug("Could not collect tool names from messages", exc_info=True) + return names diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/openai_agents/lifecycle.py index a1cee4b..ed24566 100644 --- a/src/layerlens/instrument/adapters/frameworks/openai_agents/lifecycle.py +++ b/src/layerlens/instrument/adapters/frameworks/openai_agents/lifecycle.py @@ -63,6 +63,16 @@ def __init__( self._framework_version: str | None = None self._trace_processor: Any | None = None self._run_starts: dict[int, int] = {} # thread_id -> start_ns + # Per-span input cache for memory persistence — _on_agent_span_start + # captures the agent input keyed by span_id so _on_agent_span_end + # can pair it with the output when calling record_memory_turn(). + # Bounded by the max-concurrent-spans of the SDK runtime; cleared + # in disconnect(). + self._span_inputs: dict[str, Any] = {} + # Per-span tool list for procedural-pattern detection. Tool spans + # arrive between agent_start and agent_end; we accumulate them + # under the parent agent span_id when available. + self._span_tools: dict[str, list[str]] = {} def connect(self) -> None: """Import openai-agents SDK and register trace processor.""" @@ -82,6 +92,8 @@ def disconnect(self) -> None: # _connected guard in emit_dict_event instead. 
self._trace_processor = None self._seen_agents.clear() + self._span_inputs.clear() + self._span_tools.clear() self._connected = False self._status = AdapterStatus.DISCONNECTED @@ -111,6 +123,13 @@ def get_adapter_info(self) -> AdapterInfo: ) def serialize_for_replay(self) -> ReplayableTrace: + """Serialize the current trace data for replay. + + Includes the per-adapter memory snapshot (cross-poll #1) under + ``metadata["memory_snapshot"]``. Combined with the SDK + ``Session``-style memory primitive, this lets a replay engine + re-hydrate the agent's recall state before re-execution. + """ return ReplayableTrace( adapter_name="OpenAIAgentsAdapter", framework=self.FRAMEWORK, @@ -118,6 +137,7 @@ def serialize_for_replay(self) -> ReplayableTrace: events=list(self._trace_events), state_snapshots=[], config={"capture_config": self._capture_config.model_dump()}, + metadata={"memory_snapshot": self.memory_snapshot_dict()}, ) # --- Framework Integration --- @@ -254,13 +274,27 @@ def _on_span_end(self, span: Any) -> None: def _on_agent_span_start(self, span: Any, data: Any) -> None: agent_name = getattr(data, "name", None) or "unknown" + span_id = getattr(span, "span_id", None) self._emit_agent_config(agent_name, data) + # Stash the input under span_id so _on_agent_span_end can pair + # it with the output for the memory recorder. The SDK's + # AgentSpanData carries the input in attributes that vary by + # SDK version — try the known keys. 
+ if span_id: + input_data = ( + getattr(data, "input", None) + or getattr(data, "input_messages", None) + or getattr(data, "messages", None) + ) + with self._adapter_lock: + self._span_inputs[str(span_id)] = input_data + self._span_tools.setdefault(str(span_id), []) self.emit_dict_event( "agent.input", { "framework": "openai_agents", "agent_name": agent_name, - "span_id": getattr(span, "span_id", None), + "span_id": span_id, "timestamp_ns": time.time_ns(), }, ) @@ -268,15 +302,31 @@ def _on_agent_span_start(self, span: Any, data: Any) -> None: def _on_agent_span_end(self, span: Any, data: Any) -> None: agent_name = getattr(data, "name", None) or "unknown" output = getattr(data, "output", None) + span_id = getattr(span, "span_id", None) self.emit_dict_event( "agent.output", { "framework": "openai_agents", "agent_name": agent_name, "output": self._safe_serialize(output), - "span_id": getattr(span, "span_id", None), + "span_id": span_id, }, ) + # Cross-poll #1: persist this turn into memory. Pair the + # cached input with the output, and roll up the tool span + # names captured between start and end. + input_data: Any = None + tool_names: list[str] = [] + if span_id: + with self._adapter_lock: + input_data = self._span_inputs.pop(str(span_id), None) + tool_names = self._span_tools.pop(str(span_id), []) + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + tools=tool_names or None, + ) def _on_generation_span_end(self, span: Any, data: Any) -> None: payload: dict[str, Any] = {"framework": "openai_agents"} @@ -317,6 +367,14 @@ def _on_function_span_end(self, span: Any, data: Any) -> None: "latency_ms": getattr(span, "duration_ms", None), }, ) + # Memory persistence: roll the tool name up to the parent agent + # span so the agent-level record_memory_turn() at span_end has + # the full tool list for procedural-pattern detection. 
+ parent_id = getattr(span, "parent_id", None) or getattr(span, "parent_span_id", None) + if parent_id: + with self._adapter_lock: + bucket = self._span_tools.setdefault(str(parent_id), []) + bucket.append(str(tool_name)) def _on_handoff_span_start(self, span: Any, data: Any) -> None: pass # Start event captured on end for complete data @@ -374,6 +432,8 @@ def on_run_end( agent_name: str | None = None, output: Any = None, error: Exception | None = None, + input_data: Any = None, + tools: list[str] | None = None, ) -> None: if not self._connected: return @@ -392,6 +452,16 @@ def on_run_end( if error: payload["error"] = str(error) self.emit_dict_event("agent.output", payload) + # Cross-poll #1: parallel memory persistence path for the + # Runner-wrapping integration (the TraceProcessor path + # records inside _on_agent_span_end). + self.record_memory_turn( + agent_name=agent_name, + input_data=self._safe_serialize(input_data), + output_data=self._safe_serialize(output), + error=str(error) if error else None, + tools=tools, + ) except Exception: logger.warning("Error in on_run_end", exc_info=True) diff --git a/tests/instrument/adapters/_base/test_memory.py b/tests/instrument/adapters/_base/test_memory.py new file mode 100644 index 0000000..2f65274 --- /dev/null +++ b/tests/instrument/adapters/_base/test_memory.py @@ -0,0 +1,465 @@ +"""Tests for the shared memory persistence module. + +Pins the contract documented in +``src/layerlens/instrument/adapters/_base/memory.py``: + +* :class:`MemorySnapshot` is content-addressable and immutable. +* :meth:`MemoryRecorder.snapshot` → :meth:`MemoryRecorder.restore` + round-trips byte-exactly (replay safety). +* Bounded buffers evict deterministically. +* Multi-tenant isolation: tenant A's snapshot cannot be restored into + tenant B's recorder. +* Thread-safe accumulation under concurrent ``record_turn`` callers. +* SHA-256 content hash is deterministic across processes / Python + invocations (canonical JSON encoding). 
+ +Cross-pollination audit reference: +``A:/tmp/adapter-cross-pollination-audit.md`` §2.1 (memory persistence). +""" + +from __future__ import annotations + +import threading +from typing import List + +import pytest + +from layerlens.instrument.adapters._base.memory import ( + DEFAULT_MAX_EPISODIC, + DEFAULT_MAX_SEMANTIC, + DEFAULT_MAX_PROCEDURAL, + MemoryRecorder, + MemorySnapshot, +) + +# --------------------------------------------------------------------------- +# Construction-time contract +# --------------------------------------------------------------------------- + + +def test_recorder_requires_non_empty_org_id() -> None: + """Empty / whitespace org_id is rejected (multi-tenancy fail-fast).""" + with pytest.raises(ValueError, match="non-empty org_id"): + MemoryRecorder(org_id="") + with pytest.raises(ValueError, match="non-empty org_id"): + MemoryRecorder(org_id=" ") + with pytest.raises(ValueError, match="non-empty org_id"): + MemoryRecorder(org_id=" \t\n") + + +def test_recorder_rejects_non_string_org_id() -> None: + """A non-string org_id is rejected at construction time.""" + with pytest.raises(ValueError, match="non-empty org_id"): + MemoryRecorder(org_id=None) # type: ignore[arg-type] + + +def test_recorder_rejects_zero_buffer_sizes() -> None: + """Bounded buffers must allow at least one entry.""" + with pytest.raises(ValueError, match="bounded buffer"): + MemoryRecorder(org_id="org-x", max_episodic=0) + with pytest.raises(ValueError, match="bounded buffer"): + MemoryRecorder(org_id="org-x", max_procedural=0) + with pytest.raises(ValueError, match="bounded buffer"): + MemoryRecorder(org_id="org-x", max_semantic=0) + + +def test_recorder_initial_state_is_empty() -> None: + """A fresh recorder snapshots to all-empty buckets at turn 0.""" + rec = MemoryRecorder(org_id="org-init") + snap = rec.snapshot() + assert snap.turn_index == 0 + assert snap.episodic == [] + assert snap.procedural == [] + assert snap.semantic == {} + assert snap.org_id == 
"org-init" + # Hash is non-empty and deterministic for the empty state. + assert isinstance(snap.content_hash, str) + assert len(snap.content_hash) == 64 # SHA-256 hex digest length + + +# --------------------------------------------------------------------------- +# Snapshot determinism (content-hash invariant) +# --------------------------------------------------------------------------- + + +def test_snapshot_hash_deterministic_for_identical_content() -> None: + """Two recorders with identical inputs produce identical content hashes.""" + rec_a = MemoryRecorder(org_id="org-deterministic") + rec_b = MemoryRecorder(org_id="org-deterministic") + + for rec in (rec_a, rec_b): + rec.record_turn(agent_name="researcher", input_data="hi", output_data="hello") + rec.record_turn(agent_name="writer", input_data="topic", output_data="draft", tools=["search", "write"]) + rec.set_semantic("user_pref:lang", "en-US") + + snap_a = rec_a.snapshot() + snap_b = rec_b.snapshot() + + assert snap_a.content_hash == snap_b.content_hash + assert snap_a.turn_index == snap_b.turn_index == 2 + + +def test_snapshot_hash_changes_when_org_id_differs() -> None: + """Same content but different tenant → different hash (no collision).""" + rec_a = MemoryRecorder(org_id="org-A") + rec_b = MemoryRecorder(org_id="org-B") + + for rec in (rec_a, rec_b): + rec.record_turn(agent_name="x", input_data="hi", output_data="ok") + + assert rec_a.snapshot().content_hash != rec_b.snapshot().content_hash + + +def test_snapshot_hash_changes_when_episodic_changes() -> None: + """Adding a turn changes the snapshot hash.""" + rec = MemoryRecorder(org_id="org-x") + h0 = rec.snapshot().content_hash + rec.record_turn(agent_name="a", input_data="i", output_data="o") + h1 = rec.snapshot().content_hash + assert h0 != h1 + + +def test_snapshot_immutability_after_recorder_mutation() -> None: + """A previously-returned snapshot is unaffected by later recorder mutation.""" + rec = MemoryRecorder(org_id="org-immut") + 
rec.record_turn(agent_name="x", input_data="hi", output_data="ok") + snap_before = rec.snapshot() + hash_before = snap_before.content_hash + episodic_len_before = len(snap_before.episodic) + + # Mutate the recorder. + rec.record_turn(agent_name="x", input_data="more", output_data="data") + rec.set_semantic("key", "value") + + # Original snapshot is unchanged. + assert snap_before.content_hash == hash_before + assert len(snap_before.episodic) == episodic_len_before + assert "key" not in snap_before.semantic + + +def test_snapshot_dict_roundtrip_preserves_hash() -> None: + """to_dict() / from_dict() round-trip preserves the snapshot identity.""" + rec = MemoryRecorder(org_id="org-roundtrip") + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=["t1", "t2"]) + rec.record_turn(agent_name="b", input_data="i2", output_data="o2", tools=["t2", "t3"]) + rec.set_semantic("k", "v") + + snap = rec.snapshot() + serialised = snap.to_dict() + restored = MemorySnapshot.from_dict(serialised) + + assert restored.content_hash == snap.content_hash + assert restored.turn_index == snap.turn_index + assert restored.episodic == snap.episodic + assert restored.semantic == snap.semantic + assert restored.org_id == snap.org_id + + +def test_snapshot_from_dict_rejects_missing_field() -> None: + """from_dict raises when a required field is missing.""" + with pytest.raises(ValueError, match="missing required field"): + MemorySnapshot.from_dict({"turn_index": 0, "episodic": [], "procedural": [], "semantic": {}}) + + +# --------------------------------------------------------------------------- +# Replay round-trip (the core determinism guarantee) +# --------------------------------------------------------------------------- + + +def test_restore_reproduces_byte_exact_snapshot() -> None: + """snapshot() → restore() → snapshot() yields identical content_hash.""" + src = MemoryRecorder(org_id="org-replay") + for i in range(5): + src.record_turn(agent_name=f"agent-{i % 2}", 
input_data=f"in-{i}", output_data=f"out-{i}") + src.set_semantic("session_summary", "user asked about pricing tiers") + snap = src.snapshot() + + # Fresh recorder, same tenant, restore from the snapshot, snapshot again. + target = MemoryRecorder(org_id="org-replay") + target.restore(snap) + restored_snap = target.snapshot() + + assert restored_snap.content_hash == snap.content_hash + assert restored_snap.turn_index == snap.turn_index + assert restored_snap.episodic == snap.episodic + assert restored_snap.procedural == snap.procedural + assert restored_snap.semantic == snap.semantic + + +def test_restore_then_record_yields_deterministic_next_state() -> None: + """After restoring identical snapshots into two recorders, the same + next-turn produces identical next-snapshots (the replay-safety + contract). + + Note: ``record_turn`` stamps a wall-clock ``timestamp_ns`` into the + episodic entry, so two recorders run at different wall-clock times + will drift in the ``timestamp_ns`` field of the *new* turn. That + timestamp is part of the documented turn shape and intentionally + ingested into the hash — replay engines suppress this drift by + capturing the original ``timestamp_ns`` from the source trace and + using it to seed the recorder's clock at restore time. The test + here proves the deterministic-everything-else contract: aside + from ``timestamp_ns``, the full state is byte-identical.""" + src = MemoryRecorder(org_id="org-det") + src.record_turn(agent_name="a", input_data="i", output_data="o") + snap = src.snapshot() + + rec_x = MemoryRecorder(org_id="org-det") + rec_y = MemoryRecorder(org_id="org-det") + rec_x.restore(snap) + rec_y.restore(snap) + + rec_x.record_turn(agent_name="b", input_data="i2", output_data="o2", tools=["t1"]) + rec_y.record_turn(agent_name="b", input_data="i2", output_data="o2", tools=["t1"]) + + snap_x = rec_x.snapshot() + snap_y = rec_y.snapshot() + # All-but-timestamp identity. 
+ assert snap_x.turn_index == snap_y.turn_index + assert snap_x.semantic == snap_y.semantic + assert snap_x.procedural == snap_y.procedural + assert len(snap_x.episodic) == len(snap_y.episodic) + for ex, ey in zip(snap_x.episodic, snap_y.episodic): + assert ex["agent_name"] == ey["agent_name"] + assert ex["input"] == ey["input"] + assert ex["output"] == ey["output"] + assert ex.get("tools") == ey.get("tools") + assert ex["turn_index"] == ey["turn_index"] + + +def test_restore_rejects_cross_tenant_snapshot() -> None: + """A snapshot from tenant A cannot be restored into a tenant-B recorder.""" + rec_a = MemoryRecorder(org_id="org-A") + rec_a.record_turn(agent_name="x", input_data="hi", output_data="ok") + snap_a = rec_a.snapshot() + + rec_b = MemoryRecorder(org_id="org-B") + with pytest.raises(ValueError, match="Cross-tenant restore is prohibited"): + rec_b.restore(snap_a) + + +def test_restore_rejects_tampered_snapshot() -> None: + """A snapshot whose content_hash does not match its content is rejected.""" + rec = MemoryRecorder(org_id="org-tamper") + rec.record_turn(agent_name="x", input_data="hi", output_data="ok") + snap = rec.snapshot() + + # Build a tampered snapshot: same hash, mutated semantic content. + tampered = MemorySnapshot( + turn_index=snap.turn_index, + episodic=list(snap.episodic), + procedural=list(snap.procedural), + semantic={"injected": "evil"}, + content_hash=snap.content_hash, # Stale — does not cover the new semantic dict. 
+ org_id=snap.org_id, + ) + target = MemoryRecorder(org_id="org-tamper") + with pytest.raises(ValueError, match="content_hash mismatch"): + target.restore(tampered) + + +# --------------------------------------------------------------------------- +# Bounded-buffer eviction +# --------------------------------------------------------------------------- + + +def test_episodic_buffer_evicts_oldest_fifo_at_cap() -> None: + """Episodic buffer drops oldest turns when ``max_episodic`` is exceeded.""" + rec = MemoryRecorder(org_id="org-evict", max_episodic=3) + for i in range(5): + rec.record_turn(agent_name="x", input_data=f"in-{i}", output_data=f"out-{i}") + + snap = rec.snapshot() + assert len(snap.episodic) == 3 + # Oldest two were dropped — turns 1 & 2. Surviving turns should be 3, 4, 5. + assert [t["turn_index"] for t in snap.episodic] == [3, 4, 5] + # Turn counter still monotonic — eviction does NOT roll back the counter. + assert snap.turn_index == 5 + + +def test_semantic_store_evicts_least_recently_set_at_cap() -> None: + """Semantic store evicts the oldest-set entry when over cap.""" + rec = MemoryRecorder(org_id="org-sem", max_semantic=2) + rec.set_semantic("k1", "v1") + rec.set_semantic("k2", "v2") + rec.set_semantic("k3", "v3") # Should evict k1. + + snap = rec.snapshot() + assert "k1" not in snap.semantic + assert snap.semantic == {"k2": "v2", "k3": "v3"} + + +def test_semantic_overwrite_refreshes_lru_position() -> None: + """Setting an existing key moves it to most-recent position.""" + rec = MemoryRecorder(org_id="org-sem-lru", max_semantic=2) + rec.set_semantic("k1", "v1") + rec.set_semantic("k2", "v2") + # Refresh k1 → now k2 is the oldest. + rec.set_semantic("k1", "v1-updated") + rec.set_semantic("k3", "v3") # Should evict k2. 
+ + snap = rec.snapshot() + assert "k2" not in snap.semantic + assert snap.semantic == {"k1": "v1-updated", "k3": "v3"} + + +def test_semantic_set_rejects_empty_key() -> None: + """An empty / whitespace key is rejected.""" + rec = MemoryRecorder(org_id="org-x") + with pytest.raises(ValueError, match="non-empty key"): + rec.set_semantic("", "v") + with pytest.raises(ValueError, match="non-empty key"): + rec.set_semantic(" ", "v") + + +def test_procedural_buffer_caps_distinct_patterns() -> None: + """Procedural store is bounded by ``max_procedural``.""" + rec = MemoryRecorder(org_id="org-proc", max_procedural=2) + # Generate >2 distinct procedural patterns — recurring tool sequences. + for cycle in range(3): + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=[f"t{cycle}-A"]) + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=[f"t{cycle}-B"]) + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=[f"t{cycle}-A"]) + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=[f"t{cycle}-B"]) + + snap = rec.snapshot() + assert len(snap.procedural) <= 2 + + +# --------------------------------------------------------------------------- +# Procedural-pattern detection +# --------------------------------------------------------------------------- + + +def test_procedural_pattern_recurrence_increments_count() -> None: + """Repeated tool sequences accumulate ``count``.""" + rec = MemoryRecorder(org_id="org-pat") + # search → write happens three times in a row. + for _ in range(3): + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=["search"]) + rec.record_turn(agent_name="a", input_data="i", output_data="o", tools=["write"]) + + snap = rec.snapshot() + # We expect at least one pattern with count >= 2. 
+ assert snap.procedural, "expected at least one procedural pattern" + counts = [p["count"] for p in snap.procedural] + assert any(c >= 2 for c in counts), f"no recurring pattern detected; got {snap.procedural}" + + +def test_procedural_pattern_skips_turns_with_no_tools() -> None: + """Two consecutive tool-less turns produce no procedural pattern.""" + rec = MemoryRecorder(org_id="org-no-tools") + rec.record_turn(agent_name="a", input_data="i", output_data="o") + rec.record_turn(agent_name="b", input_data="i", output_data="o") + snap = rec.snapshot() + assert snap.procedural == [] + + +# --------------------------------------------------------------------------- +# Per-turn truncation (defence-in-depth, not a policy substitute) +# --------------------------------------------------------------------------- + + +def test_oversized_turn_value_is_capped() -> None: + """An oversized (100 KB) string in a turn is hard-capped to prevent overflow.""" + rec = MemoryRecorder(org_id="org-cap") + huge = "x" * 100_000 # 100 KB string. 
+ rec.record_turn(agent_name="a", input_data=huge, output_data=huge) + + snap = rec.snapshot() + captured = snap.episodic[0] + assert "<...truncated:orig_len=100000>" in captured["input"] + assert "<...truncated:orig_len=100000>" in captured["output"] + + +# --------------------------------------------------------------------------- +# Thread safety +# --------------------------------------------------------------------------- + + +def test_concurrent_record_turn_calls_serialise_correctly() -> None: + """Many threads recording turns simultaneously produce a consistent snapshot.""" + rec = MemoryRecorder(org_id="org-thread", max_episodic=10_000) + + def worker(start: int) -> None: + for i in range(50): + rec.record_turn( + agent_name=f"t-{start}", + input_data=f"i-{start}-{i}", + output_data=f"o-{start}-{i}", + ) + + threads: List[threading.Thread] = [threading.Thread(target=worker, args=(n,)) for n in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + snap = rec.snapshot() + assert snap.turn_index == 8 * 50 + assert len(snap.episodic) == 8 * 50 + # Every turn_index from 1..400 appears exactly once (no duplicates, + # no gaps) — the lock guarantees serialisation. + indices = sorted(t["turn_index"] for t in snap.episodic) + assert indices == list(range(1, 8 * 50 + 1)) + + +# --------------------------------------------------------------------------- +# Clear / reset +# --------------------------------------------------------------------------- + + +def test_clear_resets_state_but_preserves_org_binding() -> None: + """``clear`` returns the recorder to empty state without releasing the tenant.""" + rec = MemoryRecorder(org_id="org-clear") + rec.record_turn(agent_name="x", input_data="i", output_data="o") + rec.set_semantic("k", "v") + + rec.clear() + + snap = rec.snapshot() + assert snap.turn_index == 0 + assert snap.episodic == [] + assert snap.semantic == {} + assert snap.org_id == "org-clear" # Binding preserved. 
+ + +# --------------------------------------------------------------------------- +# Default constants surface +# --------------------------------------------------------------------------- + + +def test_default_constants_are_positive() -> None: + """Documented defaults are sensible non-zero values.""" + assert DEFAULT_MAX_EPISODIC > 0 + assert DEFAULT_MAX_PROCEDURAL > 0 + assert DEFAULT_MAX_SEMANTIC > 0 + + +# --------------------------------------------------------------------------- +# Empty episodic / extra metadata edge cases +# --------------------------------------------------------------------------- + + +def test_extra_metadata_is_sorted_for_determinism() -> None: + """Two callers passing the same ``extra`` dict in different key orders + produce the same hash — the recorder sorts ``extra`` keys.""" + rec_a = MemoryRecorder(org_id="org-ex") + rec_b = MemoryRecorder(org_id="org-ex") + + # Build dicts with deliberately different insertion order. + extra_x = {"zeta": 1, "alpha": 2, "mu": 3} + extra_y = {"alpha": 2, "mu": 3, "zeta": 1} + + rec_a.record_turn(agent_name="x", input_data="i", output_data="o", extra=extra_x) + rec_b.record_turn(agent_name="x", input_data="i", output_data="o", extra=extra_y) + + assert rec_a.snapshot().content_hash == rec_b.snapshot().content_hash + + +def test_record_turn_returns_new_turn_index() -> None: + """``record_turn`` returns the post-increment counter (caller convenience).""" + rec = MemoryRecorder(org_id="org-ret") + assert rec.record_turn(agent_name="x", input_data="i", output_data="o") == 1 + assert rec.record_turn(agent_name="x", input_data="i", output_data="o") == 2 + assert rec.turn_index == 2 diff --git a/tests/instrument/adapters/frameworks/test_memory_persistence_wiring.py b/tests/instrument/adapters/frameworks/test_memory_persistence_wiring.py new file mode 100644 index 0000000..c238910 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_memory_persistence_wiring.py @@ -0,0 +1,183 @@ +"""Per-adapter 
memory persistence wiring smoke tests. + +For each of the six target adapters wired in this PR (cross-poll #1), +verify that: + +* The adapter inherits a per-instance :class:`MemoryRecorder` from + :class:`BaseAdapter`. +* The recorder is bound to the adapter's ``org_id`` (multi-tenancy). +* :meth:`serialize_for_replay` includes a content-addressable + ``memory_snapshot`` dict in the :class:`ReplayableTrace` metadata. +* The snapshot is restorable into a fresh recorder (replay round-trip). + +These tests exercise the wiring contract — the deeper behavioural +unit tests for the recorder itself live in +``tests/instrument/adapters/_base/test_memory.py`` (27 tests). + +Browser_use is intentionally **not** included — that adapter is not +present on this PR's base branch (see +``docs/adapters/memory-contract.md`` "Honest scope disclosure"). When +the histories merge, this test module should be extended with a +``BrowserUseAdapter`` parametrize entry. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Type + +import pytest + +from layerlens.instrument.adapters._base import ( + BaseAdapter, + MemoryRecorder, + MemorySnapshot, +) + + +class _RecordingStratix: + """Minimal stratix stand-in carrying a tenant binding.""" + + org_id: str = "test-org-mem" + + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +def _adapter_classes() -> List[Type[BaseAdapter]]: + """Return the list of target adapter classes for this PR. + + Imported inside the function rather than at module top level. The + imports are not individually guarded, so a missing adapter (e.g. + ``browser_use``) still surfaces as a collection-time ImportError. 
+ """ + classes: List[Type[BaseAdapter]] = [] + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + from layerlens.instrument.adapters.frameworks.google_adk import GoogleADKAdapter + from layerlens.instrument.adapters.frameworks.llama_index import LlamaIndexAdapter + from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter + from layerlens.instrument.adapters.frameworks.bedrock_agents import BedrockAgentsAdapter + from layerlens.instrument.adapters.frameworks.ms_agent_framework import MSAgentAdapter + + classes.extend([ + AgnoAdapter, + BedrockAgentsAdapter, + GoogleADKAdapter, + LlamaIndexAdapter, + MSAgentAdapter, + OpenAIAgentsAdapter, + ]) + return classes + + +@pytest.fixture(params=_adapter_classes(), ids=lambda c: c.__name__) +def adapter_cls(request: pytest.FixtureRequest) -> Type[BaseAdapter]: + return request.param # type: ignore[no-any-return] + + +def test_adapter_owns_memory_recorder_bound_to_org(adapter_cls: Type[BaseAdapter]) -> None: + """Every wired adapter exposes a recorder bound to its tenant.""" + stratix = _RecordingStratix() + adapter = adapter_cls(stratix=stratix) + + assert isinstance(adapter.memory_recorder, MemoryRecorder) + assert adapter.memory_recorder.org_id == "test-org-mem" + assert adapter.org_id == adapter.memory_recorder.org_id + + +def test_record_memory_turn_advances_episodic_buffer(adapter_cls: Type[BaseAdapter]) -> None: + """The BaseAdapter helper feeds the recorder.""" + stratix = _RecordingStratix() + adapter = adapter_cls(stratix=stratix) + initial = adapter.memory_recorder.snapshot() + assert initial.turn_index == 0 + assert initial.episodic == [] + + adapter.record_memory_turn( + agent_name="test-agent", + input_data="hello", + output_data="world", + tools=["search"], + ) + + snap = adapter.memory_recorder.snapshot() + assert snap.turn_index == 1 + assert len(snap.episodic) == 1 + assert snap.episodic[0]["agent_name"] == "test-agent" + assert snap.episodic[0]["input"] == 
"hello" + assert snap.episodic[0]["output"] == "world" + assert snap.episodic[0]["tools"] == ["search"] + + +def test_serialize_for_replay_embeds_memory_snapshot(adapter_cls: Type[BaseAdapter]) -> None: + """Each adapter ships ``metadata["memory_snapshot"]`` in the replay trace.""" + stratix = _RecordingStratix() + adapter = adapter_cls(stratix=stratix) + adapter.record_memory_turn( + agent_name="agent", + input_data="i", + output_data="o", + tools=["t"], + ) + + trace = adapter.serialize_for_replay() + assert "memory_snapshot" in trace.metadata, ( + f"{adapter_cls.__name__}.serialize_for_replay() must embed " + "metadata['memory_snapshot'] for replay-safe memory restoration" + ) + + snapshot_dict = trace.metadata["memory_snapshot"] + assert isinstance(snapshot_dict, dict) + # The dict must round-trip into a MemorySnapshot. + snapshot = MemorySnapshot.from_dict(snapshot_dict) + assert snapshot.turn_index == 1 + assert snapshot.org_id == "test-org-mem" + assert len(snapshot.episodic) == 1 + + +def test_replay_engine_can_restore_recorder_from_serialized_trace( + adapter_cls: Type[BaseAdapter], +) -> None: + """Replay-safety smoke: serialise → from_dict → restore → snapshot match.""" + stratix = _RecordingStratix() + src_adapter = adapter_cls(stratix=stratix) + for i in range(3): + src_adapter.record_memory_turn( + agent_name="a", + input_data=f"in-{i}", + output_data=f"out-{i}", + tools=["t1"], + ) + src_adapter.memory_recorder.set_semantic("session_summary", "user asked about pricing") + src_trace = src_adapter.serialize_for_replay() + src_snapshot = src_adapter.memory_recorder.snapshot() + + # Replay-side: fresh adapter + restore from the trace. 
+ replay_adapter = adapter_cls(stratix=_RecordingStratix()) + snapshot = MemorySnapshot.from_dict(src_trace.metadata["memory_snapshot"]) + replay_adapter.memory_recorder.restore(snapshot) + restored_snapshot = replay_adapter.memory_recorder.snapshot() + + assert restored_snapshot.content_hash == src_snapshot.content_hash + assert restored_snapshot.turn_index == src_snapshot.turn_index + assert restored_snapshot.episodic == src_snapshot.episodic + assert restored_snapshot.semantic == src_snapshot.semantic + + +def test_recorder_rejects_cross_tenant_snapshot_via_adapter(adapter_cls: Type[BaseAdapter]) -> None: + """Tenant-A adapter cannot accept a tenant-B snapshot at the recorder boundary.""" + stratix_a = _RecordingStratix() + stratix_a.org_id = "tenant-A" + adapter_a = adapter_cls(stratix=stratix_a) + adapter_a.record_memory_turn(agent_name="x", input_data="i", output_data="o") + snap_a = adapter_a.memory_recorder.snapshot() + + stratix_b = _RecordingStratix() + stratix_b.org_id = "tenant-B" + adapter_b = adapter_cls(stratix=stratix_b) + + with pytest.raises(ValueError, match="Cross-tenant restore is prohibited"): + adapter_b.memory_recorder.restore(snap_a)
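
The recorder behaviour these tests pin (a SHA-256 content hash over canonical JSON, cross-tenant rejection, and tamper detection on restore) can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in `_base/memory.py`: the names `TinySnapshot`, `TinyRecorder`, and `canonical_hash` are hypothetical, the procedural bucket and `timestamp_ns` are omitted, and only the error-message fragments are taken from the tests above.

```python
from __future__ import annotations

import hashlib
import json
from dataclasses import dataclass
from typing import Any


def canonical_hash(org_id: str, turn_index: int, episodic: list, semantic: dict) -> str:
    """SHA-256 over a canonical JSON encoding (sorted keys, fixed separators)."""
    payload = {
        "org_id": org_id,
        "turn_index": turn_index,
        "episodic": episodic,
        "semantic": semantic,
    }
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(encoded.encode("utf-8")).hexdigest()


@dataclass(frozen=True)
class TinySnapshot:  # hypothetical stand-in for MemorySnapshot
    org_id: str
    turn_index: int
    episodic: list
    semantic: dict
    content_hash: str


class TinyRecorder:  # hypothetical stand-in for MemoryRecorder
    def __init__(self, org_id: str) -> None:
        # Fail fast: there is no default or blank tenant.
        if not isinstance(org_id, str) or not org_id.strip():
            raise ValueError("TinyRecorder requires a non-empty org_id")
        self.org_id = org_id
        self.turn_index = 0
        self.episodic: list[dict[str, Any]] = []
        self.semantic: dict[str, str] = {}

    def record_turn(self, agent_name: str, input_data: Any, output_data: Any) -> int:
        self.turn_index += 1
        self.episodic.append(
            {
                "turn_index": self.turn_index,
                "agent_name": agent_name,
                "input": input_data,
                "output": output_data,
            }
        )
        return self.turn_index

    def snapshot(self) -> TinySnapshot:
        # Copy the buckets so later recorder mutation cannot alter the snapshot.
        episodic = [dict(turn) for turn in self.episodic]
        semantic = dict(self.semantic)
        digest = canonical_hash(self.org_id, self.turn_index, episodic, semantic)
        return TinySnapshot(self.org_id, self.turn_index, episodic, semantic, digest)

    def restore(self, snap: TinySnapshot) -> None:
        # Tenant check first: never load another tenant's memory.
        if snap.org_id != self.org_id:
            raise ValueError("Cross-tenant restore is prohibited")
        # Integrity check: recompute the hash from the content itself.
        expected = canonical_hash(snap.org_id, snap.turn_index, snap.episodic, snap.semantic)
        if expected != snap.content_hash:
            raise ValueError("content_hash mismatch")
        self.turn_index = snap.turn_index
        self.episodic = [dict(turn) for turn in snap.episodic]
        self.semantic = dict(snap.semantic)
```

Because `org_id` is part of the hashed payload, two tenants with otherwise identical content get different hashes, and a snapshot whose buckets were edited after hashing fails the recomputation in `restore`.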