From a7a4e992abafc8979dd5148abd8f90f81dd1aa3a Mon Sep 17 00:00:00 2001 From: mmercuri Date: Sun, 26 Apr 2026 17:11:38 -0700 Subject: [PATCH 1/2] feat(instrument): per-callback try/except resilience wrapper across 10 lighter adapters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a shared @resilient_callback decorator + ResilienceTracker under `src/layerlens/instrument/adapters/_base/`, then applies it to every callback method on the 10 lighter framework adapters (agno, llamaindex, google_adk, strands, pydantic_ai, smolagents, bedrock_agents, openai_agents, haystack, langfuse) so an exception in our observability code can never crash the customer's framework execution. What the decorator does on failure: 1. Catches Exception (NOT BaseException — KeyboardInterrupt / SystemExit still propagate so users can Ctrl-C their agent). 2. Logs the exception via the wrapped function's module logger with adapter_name + callback_name + truncated traceback. 3. Increments the adapter's per-instance ResilienceTracker counter. 4. Returns the framework's expected default value — None for void handlers, or the value of `passthrough_arg` for mutating hooks (Pydantic-AI's `after_model_request` returns the response object; `before_tool_execute` returns the args tuple). Health surfacing: - FrameworkAdapter now owns a `_resilience: ResilienceTracker` attribute set in `__init__` so every framework adapter inherits the contract. - `adapter_info().metadata` merges the live resilience snapshot (`resilience_status`, `resilience_failures_total`, `resilience_failure_threshold`, per-callback breakdown, last error). - After DEFAULT_FAILURE_THRESHOLD (5) failures the adapter reports `resilience_status: "degraded"` so monitoring can alert. - `disconnect()` resets the tracker so reconnects start clean. Per-adapter callback audit + fixes: | Adapter | Callbacks wrapped | Notes | |-----------------|-------------------|----------------------------------------| | agno | 2 | _on_run_start, _on_run_end | | llamaindex | 16 | 3 span lifecycle + dispatcher + 12 events| | google_adk | 11 | All adapter _on_* + simplified plugin shims| | strands | 7 | All hook handlers (replaces manual try/except)| | pydantic_ai | 9 (incl 3 split) | Error hooks split: telemetry resilient, re-raise unconditional| | smolagents | 6 | Run/step handlers (replaces manual try/except)| | bedrock_agents | 2 | _before_invoke + _after_invoke (with try/finally for _end_run)| | openai_agents | 3 | on_trace_start/_end + on_span_end (replaces manual try/except)| | haystack | 1 | _on_span_end (replaces manual try/except)| | langfuse | 5 | _import_single_trace, _import_observation, _import_score, _export_single_trace, plus inner emit fallbacks| | TOTAL | 62 | | Pydantic-AI error-callback split: `_on_run_error`, `_on_model_request_error`, `_on_tool_execute_error` MUST always re-raise the framework's original error (per Pydantic-AI's contract). The telemetry side is moved into a `_emit_*_error_telemetry` helper wrapped with @resilient_callback; the public hook calls it then unconditionally `raise error`. So adapter-side telemetry bugs can never swallow a real framework error. Tests: - `tests/instrument/adapters/_base/test_resilience.py` — 34 tests covering tracker mechanics, decorator behaviour, passthrough args, KeyboardInterrupt propagation, FrameworkAdapter integration, package re-exports, and decorator metadata preservation. - `tests/instrument/adapters/_base/test_per_adapter_resilience.py` — per-adapter smoke tests (one per lighter adapter) that simulate a callback exception by sabotaging an inner helper, plus a parametrized health-degradation test across all 10 adapters. Refactor: `_base.py` (the AdapterInfo + BaseAdapter module) becomes `_base/` package with `__init__.py` re-exporting from `_core.py` (moved via `git mv`) and the new `resilience.py`. All existing `from .._base import AdapterInfo, BaseAdapter` imports continue working unchanged. Acceptance: - pytest tests/instrument/adapters/_base/test_resilience.py -x — 34 passed - pytest tests/instrument/adapters/frameworks/ -x — 146 passed (12 skipped for missing optional deps; 2 deselected pre-existing Windows clock-resolution flakes in test_haystack) - mypy --strict src/layerlens/instrument/adapters/_base/resilience.py — Success - mypy src — Success: 169 source files - ruff check — All checks passed - Full test suite: 1090 passed --- .../instrument/adapters/_base/__init__.py | 28 + .../adapters/{_base.py => _base/_core.py} | 0 .../instrument/adapters/_base/resilience.py | 394 +++++++++++++ .../adapters/frameworks/_base_framework.py | 25 +- .../instrument/adapters/frameworks/agno.py | 3 + .../adapters/frameworks/bedrock_agents.py | 91 +-- .../adapters/frameworks/google_adk.py | 74 +-- .../adapters/frameworks/haystack.py | 13 +- .../adapters/frameworks/langfuse.py | 124 ++-- .../adapters/frameworks/llamaindex.py | 36 +- .../adapters/frameworks/openai_agents.py | 69 ++- .../adapters/frameworks/pydantic_ai.py | 49 +- .../adapters/frameworks/smolagents.py | 17 +- .../instrument/adapters/frameworks/strands.py | 279 +++++---- tests/instrument/adapters/_base/__init__.py | 0 .../_base/test_per_adapter_resilience.py | 537 ++++++++++++++++++ .../adapters/_base/test_resilience.py | 500 ++++++++++++++++ .../adapters/frameworks/test_langfuse.py | 12 +- 18 files changed, 1895 insertions(+), 356 deletions(-) create mode 100644 src/layerlens/instrument/adapters/_base/__init__.py rename src/layerlens/instrument/adapters/{_base.py => _base/_core.py} (100%) create mode 100644 src/layerlens/instrument/adapters/_base/resilience.py create mode 100644 tests/instrument/adapters/_base/__init__.py create mode 100644 tests/instrument/adapters/_base/test_per_adapter_resilience.py create mode 100644 tests/instrument/adapters/_base/test_resilience.py diff --git a/src/layerlens/instrument/adapters/_base/__init__.py b/src/layerlens/instrument/adapters/_base/__init__.py new file mode 100644 index 00000000..c2775780 --- /dev/null +++ b/src/layerlens/instrument/adapters/_base/__init__.py @@ -0,0 +1,28 @@ +"""Shared base infrastructure for all instrument adapters. + +This package re-exports the public surface of the base adapter contract +(``AdapterInfo``, ``BaseAdapter``) and the resilience helpers used by +framework adapters to wrap callbacks in try/except boundaries so that +observability code never breaks the user's framework. +""" + +from __future__ import annotations + +from ._core import AdapterInfo, BaseAdapter +from .resilience import ( + DEFAULT_FAILURE_THRESHOLD, + HealthStatus, + ResilienceTracker, + get_default_for, + resilient_callback, +) + +__all__ = [ + "AdapterInfo", + "BaseAdapter", + "DEFAULT_FAILURE_THRESHOLD", + "HealthStatus", + "ResilienceTracker", + "get_default_for", + "resilient_callback", +] diff --git a/src/layerlens/instrument/adapters/_base.py b/src/layerlens/instrument/adapters/_base/_core.py similarity index 100% rename from src/layerlens/instrument/adapters/_base.py rename to src/layerlens/instrument/adapters/_base/_core.py diff --git a/src/layerlens/instrument/adapters/_base/resilience.py b/src/layerlens/instrument/adapters/_base/resilience.py new file mode 100644 index 00000000..0b57ab3f --- /dev/null +++ b/src/layerlens/instrument/adapters/_base/resilience.py @@ -0,0 +1,394 @@ +"""Per-callback try/except resilience wrapper for adapter callbacks. + +Mature framework adapters (CrewAI, AutoGen, OpenAI Agents, Google ADK, +Strands, Bedrock Agents) wrap every observability callback in a +try/except boundary so an exception in the adapter never escapes back +into the framework's own execution path. Lighter adapters historically +relied on outer wrappers — meaning a single bug in our callback could +crash a customer's agent run. + +This module exposes a shared decorator (``resilient_callback``) and a +per-adapter failure tracker (``ResilienceTracker``) so every framework +adapter can apply the SAME resilience contract: + +1. Catch ``Exception`` (NOT ``BaseException`` — KeyboardInterrupt / + SystemExit / GeneratorExit must still propagate). +2. Log the exception via the adapter's logger with + ``adapter_name``, ``callback_name``, and a truncated traceback. +3. Increment the adapter's ``_resilience_failures`` counter. +4. Return the framework's expected default value for the callback so + the framework continues uninterrupted. + +The failure counter is consulted by ``ResilienceTracker.health_status`` +which returns ``HealthStatus.DEGRADED`` once the adapter has crossed +``DEFAULT_FAILURE_THRESHOLD`` failures within the lifetime of the run. +Adapters surface this in their ``adapter_info().metadata`` block. + +This module is **adapter-internal infrastructure**. It is NOT public +API for end users — there are no version guarantees on the helpers +exposed here, only on the BaseAdapter contract. +""" + +from __future__ import annotations + +import enum +import logging +import functools +import threading +import traceback +from typing import Any, Dict, TypeVar, Callable, Optional, cast + +log = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Public constants & enums +# --------------------------------------------------------------------------- + + +DEFAULT_FAILURE_THRESHOLD: int = 5 +"""Number of resilience failures before an adapter is marked DEGRADED. + +Chosen as a balance between fast detection (catch persistent bugs in +adapter wiring quickly) and not flapping on transient framework quirks +(a single bad event from a flaky upstream shouldn't degrade the entire +adapter). Adapters can override this via ``ResilienceTracker(threshold=...)``. +""" + + +_TRACEBACK_TRUNCATION: int = 4000 +"""Maximum characters of formatted traceback to log per failure. + +Prevents log spam from huge tracebacks (deep async stacks under +LangGraph or LlamaIndex can produce >10kB tracebacks per failure). +""" + + +class HealthStatus(str, enum.Enum): + """Adapter health states surfaced via ``adapter_info().metadata``.""" + + HEALTHY = "healthy" + DEGRADED = "degraded" + + +# --------------------------------------------------------------------------- +# Default-value table for known callbacks +# --------------------------------------------------------------------------- +# +# Many framework callback APIs require the callback to RETURN something +# (not just produce side-effects). For example: +# * Pydantic-AI ``after_model_request`` is expected to return the +# (possibly-mutated) response object — returning ``None`` would replace +# the LLM response with ``None`` and break the agent. +# * Pydantic-AI ``before_tool_execute`` returns the (possibly-mutated) +# args tuple — returning ``None`` would erase the tool args. +# * Google ADK plugin callbacks are documented as returning ``None`` +# (no override semantics) — ``None`` is the correct default. +# * Strands hook callbacks return ``None``. +# * boto3 event-system handlers (Bedrock Agents) return ``None``. +# +# When a callback that needs a passthrough (e.g. Pydantic-AI mutating +# hooks) raises, returning ``None`` would corrupt the framework's data +# flow. Adapters can pass ``passthrough_arg`` to ``resilient_callback`` +# so the wrapper returns that argument's value instead of the default. + +_DEFAULTS: Dict[str, Any] = { + # Google ADK plugin callbacks — all return None (no override hook). + "before_run_callback": None, + "after_run_callback": None, + "before_agent_callback": None, + "after_agent_callback": None, + "before_model_callback": None, + "after_model_callback": None, + "on_model_error_callback": None, + "before_tool_callback": None, + "after_tool_callback": None, + "on_tool_error_callback": None, + "on_event_callback": None, + # Strands hooks — sync, return None. + "_on_agent_initialized": None, + "_on_before_invocation": None, + "_on_after_invocation": None, + "_on_before_model": None, + "_on_after_model": None, + "_on_before_tool": None, + "_on_after_tool": None, + # OpenAI Agents TracingProcessor — return None. + "on_trace_start": None, + "on_trace_end": None, + "on_span_start": None, + "on_span_end": None, + "shutdown": None, + "force_flush": None, + # boto3 event handlers (Bedrock Agents). + "_before_invoke": None, + "_after_invoke": None, +} + + +def get_default_for(callback_name: str) -> Any: + """Return the framework-expected default for *callback_name*, or ``None``. + + The default of ``None`` is correct for the overwhelming majority of + callback APIs across instrumented frameworks (boto3 event system, + LlamaIndex span/event handlers, Strands hooks, Google ADK plugins, + OpenAI Agents TracingProcessor). For callbacks that need to return a + passthrough value (Pydantic-AI mutating hooks), use ``resilient_callback`` + with ``passthrough_arg`` instead. + """ + return _DEFAULTS.get(callback_name) + + +# --------------------------------------------------------------------------- +# Failure tracker +# --------------------------------------------------------------------------- + + +class ResilienceTracker: + """Per-adapter failure counter + degraded-health surface. + + Each framework adapter instantiates one tracker (in ``__init__``). + ``resilient_callback`` records failures via :meth:`record_failure`. + The adapter's ``adapter_info()`` reports current health via + :meth:`health_status` and a snapshot of recent failures via + :meth:`as_metadata`. + + The tracker is thread-safe: framework callbacks can fire from worker + threads (CrewAI dispatches across threads, AutoGen group chat fans + out, Bedrock boto3 hooks run in the request thread). + """ + + def __init__( + self, + adapter_name: str, + threshold: int = DEFAULT_FAILURE_THRESHOLD, + ) -> None: + if threshold < 1: + raise ValueError("threshold must be >= 1") + self._adapter_name = adapter_name + self._threshold = threshold + self._lock = threading.Lock() + self._total_failures: int = 0 + self._per_callback_failures: Dict[str, int] = {} + self._last_error: Optional[str] = None + self._last_callback: Optional[str] = None + + # -- recording -------------------------------------------------------- + + def record_failure(self, callback_name: str, exc: BaseException) -> None: + """Atomically record a failed callback invocation.""" + with self._lock: + self._total_failures += 1 + self._per_callback_failures[callback_name] = self._per_callback_failures.get(callback_name, 0) + 1 + self._last_callback = callback_name + self._last_error = f"{type(exc).__name__}: {exc}"[:500] + + def reset(self) -> None: + """Clear all failure state. Adapters call this on ``disconnect()``.""" + with self._lock: + self._total_failures = 0 + self._per_callback_failures.clear() + self._last_error = None + self._last_callback = None + + # -- queries ---------------------------------------------------------- + + @property + def total_failures(self) -> int: + with self._lock: + return self._total_failures + + @property + def threshold(self) -> int: + return self._threshold + + def health_status(self) -> HealthStatus: + """Return DEGRADED once total failures cross the threshold.""" + with self._lock: + return HealthStatus.DEGRADED if self._total_failures >= self._threshold else HealthStatus.HEALTHY + + def as_metadata(self) -> Dict[str, Any]: + """Snapshot for inclusion in ``adapter_info().metadata``.""" + with self._lock: + data: Dict[str, Any] = { + "resilience_status": ( + HealthStatus.DEGRADED.value if self._total_failures >= self._threshold else HealthStatus.HEALTHY.value + ), + "resilience_failures_total": self._total_failures, + "resilience_failure_threshold": self._threshold, + } + if self._per_callback_failures: + # Cap to top 20 so metadata payloads don't explode for + # adapters with many distinct callbacks. + top = sorted( + self._per_callback_failures.items(), + key=lambda kv: kv[1], + reverse=True, + )[:20] + data["resilience_failures_by_callback"] = dict(top) + if self._last_error is not None: + data["resilience_last_error"] = self._last_error + if self._last_callback is not None: + data["resilience_last_callback"] = self._last_callback + return data + + +# --------------------------------------------------------------------------- +# The decorator +# --------------------------------------------------------------------------- + + +F = TypeVar("F", bound=Callable[..., Any]) + + +def resilient_callback( + *, + callback_name: Optional[str] = None, + default: Any = None, + passthrough_arg: Optional[str] = None, + logger: Optional[logging.Logger] = None, +) -> Callable[[F], F]: + """Wrap a bound adapter method so observability errors never escape. + + The decorator must be applied to *instance methods* of an adapter + class. The adapter MUST expose: + + * ``self.name`` (or fall back to the class name) for logging context + * ``self._resilience`` (a :class:`ResilienceTracker`) for failure + recording + + On exception inside the wrapped method: + + 1. The exception is caught (excluding ``BaseException`` subclasses + like ``KeyboardInterrupt``). + 2. ``self._resilience.record_failure(name, exc)`` is invoked. + 3. The exception is logged via *logger* (or the adapter's module + logger) at WARNING level with a truncated traceback. + 4. The wrapper returns ``default``, OR the value of the keyword/positional + argument named *passthrough_arg* (so frameworks that expect a + mutating callback to return the passed-through value still work). + + Parameters + ---------- + callback_name: + Name to use in failure tracking and log records. Defaults to the + wrapped function's ``__name__``. + default: + Value to return when the wrapped method raises. + Use the framework's expected return type for this callback — + e.g. ``None`` for void handlers, ``""`` for handlers expected to + return a string, the original ``args`` tuple for mutating hooks. + For common callback names, the table in :func:`get_default_for` + provides the canonical default. + passthrough_arg: + If set, the wrapper returns the value of this argument (looked + up in *kwargs* first, then matched positionally if needed) on + failure. Use this for mutating hooks (Pydantic-AI + ``after_model_request`` returns the response object; + ``before_tool_execute`` returns the args tuple). When both + *passthrough_arg* and *default* are set, *passthrough_arg* wins + when the argument is present; otherwise *default* is used. + logger: + Logger to emit failure messages to. Defaults to the module + logger of the wrapped function. + """ + + def _decorate(func: F) -> F: + cb_name = callback_name or func.__name__ + # Resolve logger lazily — the wrapped function's module is the + # right logger context for warnings (so users can mute one + # adapter's resilience warnings without muting all of them). + bound_logger = logger or logging.getLogger(func.__module__) + + @functools.wraps(func) + def _wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + try: + return func(self, *args, **kwargs) + except Exception as exc: # noqa: BLE001 — intentional broad catch + _on_failure( + self, + cb_name=cb_name, + exc=exc, + bound_logger=bound_logger, + ) + return _resolve_return_value( + args=args, + kwargs=kwargs, + func=func, + passthrough_arg=passthrough_arg, + default=default, + ) + + return cast(F, _wrapper) + + return _decorate + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _on_failure( + adapter: Any, + *, + cb_name: str, + exc: BaseException, + bound_logger: logging.Logger, +) -> None: + """Record + log a callback failure on *adapter*'s resilience tracker. + + Best-effort: if the adapter doesn't have a tracker (programming + error), we still log the failure so the user sees it. + """ + adapter_name = getattr(adapter, "name", None) or type(adapter).__name__ + tracker = getattr(adapter, "_resilience", None) + if isinstance(tracker, ResilienceTracker): + tracker.record_failure(cb_name, exc) + + tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) + if len(tb) > _TRACEBACK_TRUNCATION: + tb = tb[: _TRACEBACK_TRUNCATION - 24] + "\n... [traceback truncated]" + + bound_logger.warning( + "layerlens: resilient_callback caught exception in %s.%s: %s\n%s", + adapter_name, + cb_name, + exc, + tb, + ) + + +def _resolve_return_value( + *, + args: tuple[Any, ...], + kwargs: Dict[str, Any], + func: Callable[..., Any], + passthrough_arg: Optional[str], + default: Any, +) -> Any: + """Compute the value to return when a wrapped callback raises. + + If *passthrough_arg* names a parameter that was actually supplied, + return its value. Otherwise return *default*. + """ + if not passthrough_arg: + return default + + # Keyword-supplied arguments are the most common case for callback + # APIs (Pydantic-AI / Google ADK / Strands all use keyword-only + # callback signatures). + if passthrough_arg in kwargs: + return kwargs[passthrough_arg] + + # Fall back to positional resolution by inspecting the function's + # parameter list (skip ``self`` which is always position 0). + try: + params = func.__code__.co_varnames[: func.__code__.co_argcount] + except AttributeError: + return default + for index, name in enumerate(params): + if name == passthrough_arg and index >= 1 and index - 1 < len(args): + return args[index - 1] + + return default diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 2e7b3e36..9de15a01 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -12,7 +12,7 @@ import threading from typing import Any, Dict, Optional -from .._base import AdapterInfo, BaseAdapter +from .._base import AdapterInfo, BaseAdapter, ResilienceTracker from ..._context import ( RunState, _pop_span, @@ -55,6 +55,18 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) self._connected = False # Subclasses populate during connect() for adapter_info() metadata self._metadata: Dict[str, Any] = {} + # Resilience: every framework adapter gets a per-instance tracker so + # @resilient_callback wrappers can record failures without each + # subclass needing to opt in. Use ``self.name`` when the subclass + # has set it (class-level), otherwise fall back to the class name. + adapter_name = getattr(type(self), "name", None) or type(self).__name__ + self._resilience: ResilienceTracker = ResilienceTracker(adapter_name) + # Public, mypy-friendly alias of the failure counter — kept as a + # property-shaped int so external callers can read it without + # importing ResilienceTracker. + # (Intentionally not @property — keeping a plain attribute would + # have masked the tracker's lock; readers should call + # ``self._resilience.total_failures`` for the up-to-date count.) # ------------------------------------------------------------------ # Per-run state (ContextVar-based isolation for concurrent runs) @@ -303,15 +315,24 @@ def disconnect(self) -> None: self._on_disconnect() self._connected = False self._metadata.clear() + # Reset resilience state so a reconnect starts from a healthy + # baseline. Failures from a previous run shouldn't degrade a + # fresh adapter session. + self._resilience.reset() def _on_disconnect(self) -> None: """Override to clean up framework-specific resources (unsubscribe, restore, etc.).""" pass def adapter_info(self) -> AdapterInfo: + # Merge live resilience snapshot into the metadata block so + # ``adapter_info().metadata['resilience_status']`` reports + # HEALTHY / DEGRADED to monitoring code without each subclass + # having to remember to do it. + merged_metadata: Dict[str, Any] = {**self._metadata, **self._resilience.as_metadata()} return AdapterInfo( name=self.name, adapter_type="framework", connected=self._connected, - metadata=self._metadata, + metadata=merged_metadata, ) diff --git a/src/layerlens/instrument/adapters/frameworks/agno.py b/src/layerlens/instrument/adapters/frameworks/agno.py index 8aa4e027..ba10fcd2 100644 --- a/src/layerlens/instrument/adapters/frameworks/agno.py +++ b/src/layerlens/instrument/adapters/frameworks/agno.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, List, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -245,6 +246,7 @@ async def _traced_arun(*args: Any, **kwargs: Any) -> Any: # Run lifecycle # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_run_start") def _on_run_start(self, agent: Any, input_data: Any) -> None: root = self._get_root_span() name = _agent_name(agent) @@ -255,6 +257,7 @@ def _on_run_start(self, agent: Any, input_data: Any) -> None: self._set_if_capturing(payload, "input", safe_serialize(input_data)) self._emit("agent.input", payload, span_id=root, parent_span_id=None, span_name=f"agno:{name}") + @resilient_callback(callback_name="_on_run_end") def _on_run_end(self, agent: Any, result: Any, error: Optional[Exception]) -> None: self._emit_output(agent, result, error) if result is not None: diff --git a/src/layerlens/instrument/adapters/frameworks/bedrock_agents.py b/src/layerlens/instrument/adapters/frameworks/bedrock_agents.py index 96f18829..4d928b17 100644 --- a/src/layerlens/instrument/adapters/frameworks/bedrock_agents.py +++ b/src/layerlens/instrument/adapters/frameworks/bedrock_agents.py @@ -3,6 +3,7 @@ import logging from typing import Any, Set, Dict, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -119,63 +120,67 @@ def _on_disconnect(self) -> None: # boto3 event hooks # ------------------------------------------------------------------ + @resilient_callback(callback_name="_before_invoke") def _before_invoke(self, **kwargs: Any) -> None: if not self._connected: return - try: - params = kwargs.get("params", {}) - agent_id = params.get("agentId", "unknown") + params = kwargs.get("params", {}) + agent_id = params.get("agentId", "unknown") - self._begin_run() - self._start_timer("invoke") + self._begin_run() + self._start_timer("invoke") - self._emit_agent_config(agent_id, params) + self._emit_agent_config(agent_id, params) - root = self._get_root_span() - payload = self._payload( - agent_id=agent_id, - session_id=params.get("sessionId"), - enable_trace=params.get("enableTrace", False), - ) - self._set_if_capturing(payload, "input", params.get("inputText")) - self._emit( - "agent.input", - payload, - span_id=root, - parent_span_id=None, - span_name="bedrock.invoke_agent", - ) - except Exception: - log.warning("layerlens: error in _before_invoke", exc_info=True) + root = self._get_root_span() + payload = self._payload( + agent_id=agent_id, + session_id=params.get("sessionId"), + enable_trace=params.get("enableTrace", False), + ) + self._set_if_capturing(payload, "input", params.get("inputText")) + self._emit( + "agent.input", + payload, + span_id=root, + parent_span_id=None, + span_name="bedrock.invoke_agent", + ) def _after_invoke(self, **kwargs: Any) -> None: + # _end_run() MUST run regardless of telemetry failures (otherwise + # collector/span ContextVars leak across boto3 calls). Keep the + # ``finally`` here at the OUTER level and delegate the resilient + # body to a helper wrapped with @resilient_callback. if not self._connected: return try: - parsed = kwargs.get("parsed", {}) - latency_ms = self._stop_timer("invoke") - output = _extract_completion(parsed) - - root = self._get_root_span() - payload = self._payload(session_id=parsed.get("sessionId")) - if latency_ms is not None: - payload["latency_ms"] = latency_ms - self._set_if_capturing(payload, "output", output) - self._emit( - "agent.output", - payload, - span_id=root, - parent_span_id=None, - span_name="bedrock.invoke_agent", - ) - - for step in _collect_steps(parsed): - self._process_step(step) - except Exception: - log.warning("layerlens: error in _after_invoke", exc_info=True) + self._after_invoke_body(**kwargs) finally: self._end_run() + @resilient_callback(callback_name="_after_invoke") + def _after_invoke_body(self, **kwargs: Any) -> None: + parsed = kwargs.get("parsed", {}) + latency_ms = self._stop_timer("invoke") + output = _extract_completion(parsed) + + root = self._get_root_span() + payload = self._payload(session_id=parsed.get("sessionId")) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + self._set_if_capturing(payload, "output", output) + self._emit( + "agent.output", + payload, + span_id=root, + parent_span_id=None, + span_name="bedrock.invoke_agent", + ) + + for step in _collect_steps(parsed): + self._process_step(step) + # ------------------------------------------------------------------ # Trace step processing # ------------------------------------------------------------------ diff --git a/src/layerlens/instrument/adapters/frameworks/google_adk.py b/src/layerlens/instrument/adapters/frameworks/google_adk.py index 9c494050..396b4380 100644 --- a/src/layerlens/instrument/adapters/frameworks/google_adk.py +++ b/src/layerlens/instrument/adapters/frameworks/google_adk.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -134,6 +135,7 @@ def _end_trace(self) -> None: # Run lifecycle handlers (called from plugin) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_run") def _on_before_run(self, invocation_context: Any) -> None: span_id = self._new_span_id() with self._lock: @@ -176,6 +178,7 @@ def _on_before_run(self, invocation_context: Any) -> None: self._set_if_capturing(payload, "input", safe_serialize(user_content)) self._fire("agent.input", payload, span_id=span_id, span_name=agent_name) + @resilient_callback(callback_name="_on_after_run") def _on_after_run(self, invocation_context: Any) -> None: latency_ms = self._tock("run") span_id = self._run_span_id or self._new_span_id() @@ -191,6 +194,7 @@ def _on_after_run(self, invocation_context: Any) -> None: # Agent lifecycle handlers # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_agent") def _on_before_agent(self, agent: Any, callback_context: Any) -> None: name = _agent_name(agent) span_id = self._new_span_id() @@ -206,6 +210,7 @@ def _on_before_agent(self, agent: Any, callback_context: Any) -> None: self._set_if_capturing(payload, "input", safe_serialize(user_content)) self._fire("agent.input", payload, span_id=span_id, parent_span_id=self._run_span_id, span_name=f"agent:{name}") + @resilient_callback(callback_name="_on_after_agent") def _on_after_agent(self, agent: Any, callback_context: Any) -> None: name = _agent_name(agent) latency_ms = self._tock(f"agent:{name}") @@ -225,10 +230,12 @@ def _on_after_agent(self, agent: Any, callback_context: Any) -> None: # Model lifecycle handlers # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_model") def _on_before_model(self, callback_context: Any, llm_request: Any) -> None: agent_name = getattr(callback_context, "agent_name", None) or "unknown" self._tick(f"model:{agent_name}") + @resilient_callback(callback_name="_on_after_model") def _on_after_model(self, callback_context: Any, llm_response: Any) -> None: agent_name = getattr(callback_context, "agent_name", None) or "unknown" latency_ms = self._tock(f"model:{agent_name}") @@ -267,6 +274,7 @@ def _on_after_model(self, callback_context: Any, llm_response: Any) -> None: cost_payload["model"] = str(model) self._fire("cost.record", cost_payload, span_id=span_id, parent_span_id=parent) + @resilient_callback(callback_name="_on_model_error") def _on_model_error(self, callback_context: Any, llm_request: Any, error: Exception) -> None: agent_name = getattr(callback_context, "agent_name", None) or "unknown" self._tock(f"model:{agent_name}") # clear timer @@ -280,11 +288,13 @@ def _on_model_error(self, callback_context: Any, llm_request: Any, error: Except # Tool lifecycle handlers # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_tool") def _on_before_tool(self, tool: Any, tool_args: Any, tool_context: Any) -> None: tool_name = getattr(tool, "name", None) or "unknown" call_id = getattr(tool_context, "function_call_id", None) or tool_name self._tick(f"tool:{call_id}") + @resilient_callback(callback_name="_on_after_tool") def _on_after_tool(self, tool: Any, tool_args: Any, tool_context: Any, result: Any) -> None: tool_name = getattr(tool, "name", None) or "unknown" call_id = getattr(tool_context, "function_call_id", None) or tool_name @@ -303,6 +313,7 @@ def _on_after_tool(self, tool: Any, tool_args: Any, tool_context: Any, result: A self._set_if_capturing(result_payload, "output", safe_serialize(result)) self._fire("tool.result", result_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") + @resilient_callback(callback_name="_on_tool_error") def _on_tool_error(self, tool: Any, tool_args: Any, tool_context: Any, error: Exception) -> None: tool_name = getattr(tool, "name", None) or "unknown" call_id = getattr(tool_context, "function_call_id", None) or tool_name @@ -317,6 +328,7 @@ def _on_tool_error(self, tool: Any, tool_args: Any, tool_context: Any, error: Ex # Event callback # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_event") def _on_event(self, invocation_context: Any, event: Any) -> None: # Detect agent handoffs from event actions actions = getattr(event, "actions", None) @@ -378,86 +390,58 @@ def _make_plugin(adapter: GoogleADKAdapter) -> Any: if _BasePlugin is None: raise ImportError("google-adk is required for GoogleADKAdapter") - class _LayerLensPlugin(_BasePlugin): + # The adapter's ``_on_*`` methods are wrapped with ``@resilient_callback`` + # which catches exceptions, logs them, and increments the resilience + # tracker — so the plugin shims can call them directly without any + # additional try/except. The shims still need to ``return None`` + # (the ADK plugin contract requires None to mean "don't override"). + class _LayerLensPlugin(_BasePlugin): # type: ignore[misc, valid-type] def __init__(self) -> None: super().__init__(name="layerlens") async def before_run_callback(self, *, invocation_context: Any) -> None: - try: - adapter._on_before_run(invocation_context) - except Exception: - log.warning("layerlens: error in before_run_callback", exc_info=True) + adapter._on_before_run(invocation_context) return None async def after_run_callback(self, *, invocation_context: Any) -> None: - try: - adapter._on_after_run(invocation_context) - except Exception: - log.warning("layerlens: error in after_run_callback", exc_info=True) + adapter._on_after_run(invocation_context) async def before_agent_callback(self, *, agent: Any, callback_context: Any) -> None: - try: - adapter._on_before_agent(agent, callback_context) - except Exception: - log.warning("layerlens: error in before_agent_callback", exc_info=True) + adapter._on_before_agent(agent, callback_context) return None async def after_agent_callback(self, *, agent: Any, callback_context: Any) -> None: - try: - adapter._on_after_agent(agent, callback_context) - except Exception: - log.warning("layerlens: error in after_agent_callback", exc_info=True) + adapter._on_after_agent(agent, callback_context) return None async def before_model_callback(self, *, callback_context: Any, llm_request: Any) -> None: - try: - adapter._on_before_model(callback_context, llm_request) - except Exception: - log.warning("layerlens: error in before_model_callback", exc_info=True) + adapter._on_before_model(callback_context, llm_request) return None async def after_model_callback(self, *, callback_context: Any, llm_response: Any) -> None: - try: - adapter._on_after_model(callback_context, llm_response) - except Exception: - log.warning("layerlens: error in after_model_callback", exc_info=True) + adapter._on_after_model(callback_context, llm_response) return None async def on_model_error_callback(self, *, callback_context: Any, llm_request: Any, error: Exception) -> None: - try: - adapter._on_model_error(callback_context, llm_request, error) - except Exception: - log.warning("layerlens: error in on_model_error_callback", exc_info=True) + adapter._on_model_error(callback_context, llm_request, error) return None async def before_tool_callback(self, *, tool: Any, tool_args: Any, tool_context: Any) -> None: - try: - adapter._on_before_tool(tool, tool_args, tool_context) - except Exception: - log.warning("layerlens: error in before_tool_callback", exc_info=True) + adapter._on_before_tool(tool, tool_args, tool_context) return None async def after_tool_callback(self, *, tool: Any, tool_args: Any, tool_context: Any, result: Any) -> None: - try: - adapter._on_after_tool(tool, tool_args, tool_context, result) - except Exception: - log.warning("layerlens: error in after_tool_callback", exc_info=True) + adapter._on_after_tool(tool, tool_args, tool_context, result) return None async def on_tool_error_callback( self, *, tool: Any, tool_args: Any, tool_context: Any, error: Exception ) -> None: - try: - adapter._on_tool_error(tool, tool_args, tool_context, error) - except Exception: - log.warning("layerlens: error in on_tool_error_callback", exc_info=True) + adapter._on_tool_error(tool, tool_args, tool_context, error) return None async def on_event_callback(self, *, invocation_context: Any, event: Any) -> None: - try: - adapter._on_event(invocation_context, event) - except Exception: - log.warning("layerlens: error in on_event_callback", exc_info=True) + adapter._on_event(invocation_context, event) return None return _LayerLensPlugin() diff --git a/src/layerlens/instrument/adapters/frameworks/haystack.py b/src/layerlens/instrument/adapters/frameworks/haystack.py index 10ee412f..a24c8bf6 100644 --- a/src/layerlens/instrument/adapters/frameworks/haystack.py +++ b/src/layerlens/instrument/adapters/frameworks/haystack.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, Optional from contextlib import contextmanager +from .._base import resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -70,15 +71,13 @@ def _on_disconnect(self) -> None: # Span handlers (called by _LayerLensSpan._finish) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_span_end") def _on_span_end(self, span: _LayerLensSpan) -> None: elapsed_ms = (time.time_ns() - span._start_ns) / 1_000_000 - try: - if span._is_pipeline: - self._on_pipeline_end(span, elapsed_ms) - elif span._operation_name == "haystack.component.run": - self._on_component_end(span, elapsed_ms) - except Exception: - log.warning("layerlens: error emitting Haystack span", exc_info=True) + if span._is_pipeline: + self._on_pipeline_end(span, elapsed_ms) + elif span._operation_name == "haystack.component.run": + self._on_component_end(span, elapsed_ms) def _on_pipeline_end(self, span: _LayerLensSpan, elapsed_ms: float) -> None: tags = span._all_tags() diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse.py b/src/layerlens/instrument/adapters/frameworks/langfuse.py index cdfe4611..35886f3a 100644 --- a/src/layerlens/instrument/adapters/frameworks/langfuse.py +++ b/src/layerlens/instrument/adapters/frameworks/langfuse.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, List, Optional +from .._base import resilient_callback from ._utils import truncate, new_span_id from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -174,15 +175,14 @@ def import_traces( imported = 0 for trace_summary in traces: - try: - self._import_single_trace(trace_summary) + # ``_import_single_trace`` is wrapped with @resilient_callback + # — a malformed trace becomes a logged warning + failure + # counter increment, NOT a halt of the batch import. Track + # success via the success counter we maintain manually. + before = self._resilience.total_failures + self._import_single_trace(trace_summary) + if self._resilience.total_failures == before: imported += 1 - except Exception: - log.warning( - "layerlens: failed to import Langfuse trace %s", - trace_summary.get("id", "?"), - exc_info=True, - ) # Advance cursor to the most recent trace timestamp latest = traces[0].get("updatedAt") or traces[0].get("timestamp") @@ -192,6 +192,7 @@ def import_traces( log.info("layerlens: imported %d Langfuse traces", imported) return imported + @resilient_callback(callback_name="_import_single_trace") def _import_single_trace(self, trace_summary: Dict[str, Any]) -> None: """Fetch a full trace and emit events via TraceCollector.""" trace_id = trace_summary["id"] @@ -219,48 +220,17 @@ def _import_single_trace(self, trace_summary: Dict[str, Any]) -> None: span_name=trace.get("name"), ) - # Process observations (generations, spans, events) + # Process observations (generations, spans, events). Inner call + # is wrapped with @resilient_callback — a malformed observation + # is logged + counted, not propagated. observations = trace.get("observations", []) for obs in observations: - try: - self._import_observation(collector, obs, root_span_id) - except Exception: - log.warning( - "layerlens: failed to import observation %s", - obs.get("id", "?"), - exc_info=True, - ) + self._import_observation(collector, obs, root_span_id) - # Scores (Langfuse "annotations") — both human annotations and LLM-as-judge - # scores land in the same collection. Emit them as evaluation.result so - # the migration path preserves all grading signal. + # Scores (Langfuse "annotations") — wrapped in a resilient + # helper so one bad score doesn't abort the whole trace import. for score in trace.get("scores", []) or []: - try: - score_payload: Dict[str, Any] = { - "framework": "langfuse", - "langfuse_trace_id": trace_id, - "name": score.get("name"), - "value": score.get("value"), - "source": score.get("source"), - "data_type": score.get("dataType"), - "observation_id": score.get("observationId"), - } - comment = score.get("comment") - if comment: - score_payload["comment"] = truncate(str(comment), max_len=2000) - # Session clustering: Langfuse groups related traces via sessionId. - # Carry it through so downstream session-level analytics work. - session_id = score.get("sessionId") or trace.get("sessionId") - if session_id: - score_payload["session_id"] = session_id - collector.emit( - "evaluation.result", - score_payload, - span_id=new_span_id(), - parent_span_id=root_span_id, - ) - except Exception: - log.warning("layerlens: failed to import score", exc_info=True) + self._import_score(collector, trace, trace_id, root_span_id, score) # Emit agent.output from trace output trace_output = trace.get("output") @@ -283,6 +253,41 @@ def _import_single_trace(self, trace_summary: Dict[str, Any]) -> None: collector.flush() + @resilient_callback(callback_name="_import_score") + def _import_score( + self, + collector: TraceCollector, + trace: Dict[str, Any], + trace_id: str, + root_span_id: str, + score: Dict[str, Any], + ) -> None: + """Emit one Langfuse score as a LayerLens evaluation.result event.""" + score_payload: Dict[str, Any] = { + "framework": "langfuse", + "langfuse_trace_id": trace_id, + "name": score.get("name"), + "value": score.get("value"), + "source": score.get("source"), + "data_type": score.get("dataType"), + "observation_id": score.get("observationId"), + } + comment = score.get("comment") + if comment: + score_payload["comment"] = truncate(str(comment), max_len=2000) + # Session clustering: Langfuse groups related traces via sessionId. + # Carry it through so downstream session-level analytics work. + session_id = score.get("sessionId") or trace.get("sessionId") + if session_id: + score_payload["session_id"] = session_id + collector.emit( + "evaluation.result", + score_payload, + span_id=new_span_id(), + parent_span_id=root_span_id, + ) + + @resilient_callback(callback_name="_import_observation") def _import_observation( self, collector: TraceCollector, @@ -472,21 +477,26 @@ def export_traces( exported = 0 for trace_id, events in events_by_trace.items(): - try: - batch = self._build_ingestion_batch(trace_id, events) - if batch: - self._post_ingestion(batch) - exported += 1 - except Exception: - log.warning( - "layerlens: failed to export trace %s to Langfuse", - trace_id, - exc_info=True, - ) + before = self._resilience.total_failures + self._export_single_trace(trace_id, events) + if self._resilience.total_failures == before: + exported += 1 log.info("layerlens: exported %d traces to Langfuse", exported) return exported + @resilient_callback(callback_name="_export_single_trace") + def _export_single_trace(self, trace_id: str, events: List[Dict[str, Any]]) -> None: + """Build + POST a single trace's ingestion batch. + + Wrapped with @resilient_callback so a single bad trace doesn't + abort the rest of the batch export. The success/failure of each + trace is tracked via the resilience tracker. + """ + batch = self._build_ingestion_batch(trace_id, events) + if batch: + self._post_ingestion(batch) + def _build_ingestion_batch( self, trace_id: str, diff --git a/src/layerlens/instrument/adapters/frameworks/llamaindex.py b/src/layerlens/instrument/adapters/frameworks/llamaindex.py index 53f1071f..008e02a4 100644 --- a/src/layerlens/instrument/adapters/frameworks/llamaindex.py +++ b/src/layerlens/instrument/adapters/frameworks/llamaindex.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, List, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -166,6 +167,7 @@ def _flush_all(self) -> None: # Span lifecycle (called by the thin span handler) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_span_enter") def _on_span_enter(self, id_: str, parent_span_id: Optional[str]) -> Any: with self._lock: span = _BaseSpan(id_=id_, parent_id=parent_span_id) @@ -175,6 +177,7 @@ def _on_span_enter(self, id_: str, parent_span_id: Optional[str]) -> Any: self._collectors[id_] = TraceCollector(self._client, self._config) return span + @resilient_callback(callback_name="_on_span_exit") def _on_span_exit(self, id_: str) -> Any: with self._lock: span = self._open_spans.get(id_) @@ -184,6 +187,7 @@ def _on_span_exit(self, id_: str) -> Any: collector.flush() return span + @resilient_callback(callback_name="_on_span_drop") def _on_span_drop(self, id_: str) -> Any: return self._on_span_exit(id_) # same cleanup @@ -191,23 +195,29 @@ def _on_span_drop(self, id_: str) -> Any: # Event dispatch (called by the thin event handler) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_handle_event") def _handle_event(self, event: Any) -> None: - try: - handler_name = self._EVENT_DISPATCH.get(type(event).__name__) - if handler_name is not None: - getattr(self, handler_name)(event) - except Exception: - log.warning("layerlens: error in LlamaIndex event handler", exc_info=True) + # Per-event handlers are individually wrapped (defense-in-depth) + # so each failed handler is recorded with its real name in the + # resilience tracker rather than being aggregated under + # ``_handle_event``. The outer wrapper here covers any failure + # in the dispatch/lookup logic itself (unknown event class, + # ``getattr`` raising, etc.). + handler_name = self._EVENT_DISPATCH.get(type(event).__name__) + if handler_name is not None: + getattr(self, handler_name)(event) # ------------------------------------------------------------------ # LLM Chat # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_llm_chat_start") def _on_llm_chat_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) if span_id: self._llm_start_times[span_id] = time.time() + @resilient_callback(callback_name="_on_llm_chat_end") def _on_llm_chat_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) response = getattr(event, "response", None) @@ -246,11 +256,13 @@ def _on_llm_chat_end(self, event: Any) -> None: # LLM Completion # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_llm_completion_start") def _on_llm_completion_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) if span_id: self._llm_start_times[span_id] = time.time() + @resilient_callback(callback_name="_on_llm_completion_end") def _on_llm_completion_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) response = getattr(event, "response", None) @@ -289,6 +301,7 @@ def _on_llm_completion_end(self, event: Any) -> None: # Tool calls # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_tool_call") def _on_tool_call(self, event: Any) -> None: span_id = getattr(event, "span_id", None) tool = getattr(event, "tool", None) @@ -310,6 +323,7 @@ def _on_tool_call(self, event: Any) -> None: # Retrieval # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_retrieval_start") def _on_retrieval_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload(tool_name="retrieval") @@ -319,6 +333,7 @@ def _on_retrieval_start(self, event: Any) -> None: payload["input"] = str(query) self._fire("tool.call", payload, span_id=span_id, span_name="retrieval") + @resilient_callback(callback_name="_on_retrieval_end") def _on_retrieval_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) nodes = getattr(event, "nodes", None) @@ -333,6 +348,7 @@ def _on_retrieval_end(self, event: Any) -> None: # Embeddings # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_embedding_start") def _on_embedding_start(self, event: Any) -> None: # When L3 model metadata is suppressed, skip the costly embedding serialization # — bulk ingestion runs fire thousands of these events and the collector @@ -346,6 +362,7 @@ def _on_embedding_start(self, event: Any) -> None: payload["model"] = model self._fire("model.invoke", payload, span_id=span_id, span_name="embedding") + @resilient_callback(callback_name="_on_embedding_end") def _on_embedding_end(self, event: Any) -> None: if not self._config.l3_model_metadata: return @@ -380,6 +397,7 @@ def _on_embedding_end(self, event: Any) -> None: # Query # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_query_start") def _on_query_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload() @@ -389,6 +407,7 @@ def _on_query_start(self, event: Any) -> None: payload["input"] = str(query) self._fire("agent.input", payload, span_id=span_id, span_name="query") + @resilient_callback(callback_name="_on_query_end") def _on_query_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload(status="ok") @@ -402,6 +421,7 @@ def _on_query_end(self, event: Any) -> None: # Agent steps # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_agent_step_start") def _on_agent_step_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload() @@ -414,6 +434,7 @@ def _on_agent_step_start(self, event: Any) -> None: payload["input"] = safe_serialize(step_input) self._fire("agent.input", payload, span_id=span_id, span_name="agent_step") + @resilient_callback(callback_name="_on_agent_step_end") def _on_agent_step_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload(status="ok") @@ -427,6 +448,7 @@ def _on_agent_step_end(self, event: Any) -> None: # Rerank # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_rerank_start") def _on_rerank_start(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload(tool_name="rerank") @@ -438,6 +460,7 @@ def _on_rerank_start(self, event: Any) -> None: payload["top_n"] = top_n self._fire("tool.call", payload, span_id=span_id, span_name="rerank") + @resilient_callback(callback_name="_on_rerank_end") def _on_rerank_end(self, event: Any) -> None: span_id = getattr(event, "span_id", None) payload = self._payload(tool_name="rerank") @@ -450,6 +473,7 @@ def _on_rerank_end(self, event: Any) -> None: # Exceptions # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_exception") def _on_exception(self, event: Any) -> None: span_id = getattr(event, "span_id", None) exc = getattr(event, "exception", None) diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents.py b/src/layerlens/instrument/adapters/frameworks/openai_agents.py index 9acd00c0..a354824c 100644 --- a/src/layerlens/instrument/adapters/frameworks/openai_agents.py +++ b/src/layerlens/instrument/adapters/frameworks/openai_agents.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Optional from datetime import datetime +from .._base import resilient_callback from ._utils import safe_serialize from ..._context import RunState, _current_run, _current_collector from ..._collector import TraceCollector @@ -83,52 +84,46 @@ def _on_disconnect(self) -> None: # TracingProcessor interface # ------------------------------------------------------------------ + @resilient_callback(callback_name="on_trace_start") def on_trace_start(self, trace: Any) -> None: - try: - # OA manages multiple concurrent traces from one processor, - # so we create RunState directly instead of using _begin_run - # (which would pollute ContextVars for the next trace). - run = RunState( - collector=TraceCollector(self._client, self._config), - root_span_id=self._new_span_id(), - ) - with self._lock: - self._trace_runs[trace.trace_id] = run - except Exception: - log.warning("layerlens: error in on_trace_start", exc_info=True) + # OA manages multiple concurrent traces from one processor, + # so we create RunState directly instead of using _begin_run + # (which would pollute ContextVars for the next trace). + run = RunState( + collector=TraceCollector(self._client, self._config), + root_span_id=self._new_span_id(), + ) + with self._lock: + self._trace_runs[trace.trace_id] = run + @resilient_callback(callback_name="on_trace_end") def on_trace_end(self, trace: Any) -> None: - try: - with self._lock: - run = self._trace_runs.pop(trace.trace_id, None) - if run is not None: - run.collector.flush() - except Exception: - log.warning("layerlens: error in on_trace_end", exc_info=True) + with self._lock: + run = self._trace_runs.pop(trace.trace_id, None) + if run is not None: + run.collector.flush() def on_span_start(self, span: Any) -> None: pass + @resilient_callback(callback_name="on_span_end") def on_span_end(self, span: Any) -> None: + with self._lock: + run = self._trace_runs.get(span.trace_id) + if run is None: + return + + # Temporarily set both ContextVars so _emit and providers work. + run_token = _current_run.set(run) + col_token = _current_collector.set(run.collector) try: - with self._lock: - run = self._trace_runs.get(span.trace_id) - if run is None: - return - - # Temporarily set both ContextVars so _emit and providers work. - run_token = _current_run.set(run) - col_token = _current_collector.set(run.collector) - try: - span_type = getattr(span.span_data, "type", None) or "" - handler_name = self._SPAN_HANDLERS.get(span_type) - if handler_name is not None: - getattr(self, handler_name)(span) - finally: - _current_collector.reset(col_token) - _current_run.reset(run_token) - except Exception: - log.warning("layerlens: error handling OpenAI Agents span", exc_info=True) + span_type = getattr(span.span_data, "type", None) or "" + handler_name = self._SPAN_HANDLERS.get(span_type) + if handler_name is not None: + getattr(self, handler_name)(span) + finally: + _current_collector.reset(col_token) + _current_run.reset(run_token) def shutdown(self) -> None: pass diff --git a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py index 63517dd9..a97913eb 100644 --- a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py +++ b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py @@ -3,6 +3,7 @@ import logging from typing import Any, Dict, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -98,6 +99,7 @@ def _register_hooks(self, hooks: Any) -> None: # Run lifecycle hooks # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_run") def _on_before_run(self, ctx: Any) -> None: self._begin_run() root = self._get_root_span() @@ -141,6 +143,7 @@ def _on_before_run(self, ctx: Any) -> None: ) self._start_timer("run") + @resilient_callback(callback_name="_on_after_run", passthrough_arg="result") def _on_after_run(self, ctx: Any, *, result: Any) -> Any: latency_ms = self._stop_timer("run") root = self._get_root_span() @@ -176,6 +179,16 @@ def _on_after_run(self, ctx: Any, *, result: Any) -> Any: return result def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: + # Telemetry is best-effort; we MUST always re-raise the + # framework's original error or PydanticAI loses its error + # propagation contract. Keep telemetry inside a resilient + # helper so adapter-side bugs can never swallow the framework + # error. + self._emit_run_error_telemetry(ctx, error=error) + raise error + + @resilient_callback(callback_name="_on_run_error") + def _emit_run_error_telemetry(self, ctx: Any, *, error: BaseException) -> None: latency_ms = self._stop_timer("run") root = self._get_root_span() agent_name = self._get_agent_name(ctx) @@ -196,12 +209,12 @@ def _on_run_error(self, ctx: Any, *, error: BaseException) -> None: ) self._end_run() - raise error # ------------------------------------------------------------------ # Model request hooks # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_after_model_request", passthrough_arg="response") def _on_after_model_request( self, ctx: Any, @@ -240,18 +253,31 @@ def _on_model_request_error( *, request_context: Any, error: Exception, + ) -> None: + # Telemetry first (resilient), THEN re-raise the framework's + # error so PydanticAI's own error propagation is preserved. + self._emit_model_request_error_telemetry(ctx, request_context=request_context, error=error) + raise error + + @resilient_callback(callback_name="_on_model_request_error") + def _emit_model_request_error_telemetry( + self, + ctx: Any, + *, + request_context: Any, + error: Exception, ) -> None: payload = self._payload( error=str(error), error_type=type(error).__name__, ) self._emit("agent.error", payload) - raise error # ------------------------------------------------------------------ # Tool execution hooks # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_before_tool_execute", passthrough_arg="args") def _on_before_tool_execute( self, ctx: Any, @@ -269,6 +295,7 @@ def _on_before_tool_execute( self._start_timer(f"tool:{call_id}") return args + @resilient_callback(callback_name="_on_after_tool_execute", passthrough_arg="result") def _on_after_tool_execute( self, ctx: Any, @@ -301,6 +328,21 @@ def _on_tool_execute_error( tool_def: Any, args: Any, error: Exception, + ) -> None: + # Telemetry first (resilient), THEN re-raise the framework's + # error so PydanticAI can propagate the tool failure. + self._emit_tool_execute_error_telemetry(ctx, call=call, tool_def=tool_def, args=args, error=error) + raise error + + @resilient_callback(callback_name="_on_tool_execute_error") + def _emit_tool_execute_error_telemetry( + self, + ctx: Any, + *, + call: Any, + tool_def: Any, + args: Any, + error: Exception, ) -> None: tool_name = getattr(call, "tool_name", "unknown") call_id = getattr(call, "id", None) or tool_name @@ -316,12 +358,12 @@ def _on_tool_execute_error( error_type=type(error).__name__, ) self._emit("agent.error", payload) - raise error # ------------------------------------------------------------------ # Streaming hooks (pydantic-ai >= 0.5) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_stream_chunk") def _on_stream_chunk(self, ctx: Any, *, chunk: Any, **_kwargs: Any) -> None: """Accumulate streaming chunks on the RunState; aggregated at stream end.""" run = self._get_run() @@ -330,6 +372,7 @@ def _on_stream_chunk(self, ctx: Any, *, chunk: Any, **_kwargs: Any) -> None: buf = run.data.setdefault("stream_buffer", []) buf.append(chunk) + @resilient_callback(callback_name="_on_after_stream") def _on_after_stream(self, ctx: Any, *, response: Any = None, **_kwargs: Any) -> None: run = self._get_run() if run is None: diff --git a/src/layerlens/instrument/adapters/frameworks/smolagents.py b/src/layerlens/instrument/adapters/frameworks/smolagents.py index 0e9c1e87..0b77f203 100644 --- a/src/layerlens/instrument/adapters/frameworks/smolagents.py +++ b/src/layerlens/instrument/adapters/frameworks/smolagents.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, List, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -190,6 +191,7 @@ def _end_trace(self) -> None: # Run lifecycle handlers # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_run_start") def _on_run_start(self, agent: Any, task: Any) -> None: span_id = self._new_span_id() with self._lock: @@ -219,6 +221,7 @@ def _on_run_start(self, agent: Any, task: Any) -> None: self._set_if_capturing(payload, "input", safe_serialize(task)) self._fire("agent.input", payload, span_id=span_id, span_name=agent_name) + @resilient_callback(callback_name="_on_run_end") def _on_run_end(self, agent: Any, result: Any, error: Optional[Exception]) -> None: latency_ms = self._tock("run") span_id = self._run_span_id or self._new_span_id() @@ -232,6 +235,7 @@ def _on_run_end(self, agent: Any, result: Any, error: Optional[Exception]) -> No self._fire("agent.output", payload, span_id=span_id, span_name=agent_name) self._end_trace() + @resilient_callback(callback_name="_on_run_error") def _on_run_error(self, agent: Any, exc: Exception) -> None: agent_name = _agent_name(agent) self._fire( @@ -244,18 +248,15 @@ def _on_run_error(self, agent: Any, exc: Exception) -> None: # Step handlers (registered as step_callbacks) # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_action_step") def _on_action_step(self, step: Any, agent: Any = None) -> None: - try: - self._handle_action_step(step, agent) - except Exception: - log.warning("layerlens: error in SmolAgents action step handler", exc_info=True) + self._handle_action_step(step, agent) + @resilient_callback(callback_name="_on_planning_step") def _on_planning_step(self, step: Any, agent: Any = None) -> None: - try: - self._handle_planning_step(step, agent) - except Exception: - log.warning("layerlens: error in SmolAgents planning step handler", exc_info=True) + self._handle_planning_step(step, agent) + @resilient_callback(callback_name="_on_final_answer_step") def _on_final_answer_step(self, step: Any, agent: Any = None) -> None: pass # run wrapper handles final output + flush diff --git a/src/layerlens/instrument/adapters/frameworks/strands.py b/src/layerlens/instrument/adapters/frameworks/strands.py index 25cec465..60168f4c 100644 --- a/src/layerlens/instrument/adapters/frameworks/strands.py +++ b/src/layerlens/instrument/adapters/frameworks/strands.py @@ -4,6 +4,7 @@ import logging from typing import Any, Dict, Optional +from .._base import resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -174,78 +175,71 @@ def _end_trace(self) -> None: # Hook handlers # ------------------------------------------------------------------ + @resilient_callback(callback_name="_on_agent_initialized") def _on_agent_initialized(self, event: Any) -> None: """Sync-only callback fired when an agent is constructed.""" - try: - agent = event.agent - name = _agent_name(agent) - self._emit_agent_config(name, agent) - except Exception: - log.warning("layerlens: error in Strands agent_initialized", exc_info=True) + agent = event.agent + name = _agent_name(agent) + self._emit_agent_config(name, agent) + @resilient_callback(callback_name="_on_before_invocation") def _on_before_invocation(self, event: Any) -> None: - try: - agent = event.agent - name = _agent_name(agent) - span_id = self._new_span_id() - with self._lock: - self._collector = TraceCollector(self._client, self._config) - self._run_span_id = span_id - self._current_agent_name = name - self._tick("run") - - # Re-emit config if we haven't seen this agent yet - self._emit_agent_config(name, agent) - - payload = self._payload(agent_name=name) - model_id = _model_id(agent) - if model_id: - payload["model"] = model_id + agent = event.agent + name = _agent_name(agent) + span_id = self._new_span_id() + with self._lock: + self._collector = TraceCollector(self._client, self._config) + self._run_span_id = span_id + self._current_agent_name = name + self._tick("run") + + # Re-emit config if we haven't seen this agent yet + self._emit_agent_config(name, agent) - messages = getattr(event, "messages", None) - self._set_if_capturing(payload, "input", safe_serialize(messages)) - self._fire("agent.input", payload, span_id=span_id, span_name=name) - except Exception: - log.warning("layerlens: error in Strands before_invocation", exc_info=True) + payload = self._payload(agent_name=name) + model_id = _model_id(agent) + if model_id: + payload["model"] = model_id + + messages = getattr(event, "messages", None) + self._set_if_capturing(payload, "input", safe_serialize(messages)) + self._fire("agent.input", payload, span_id=span_id, span_name=name) + @resilient_callback(callback_name="_on_after_invocation") def _on_after_invocation(self, event: Any) -> None: - try: - agent = event.agent - name = _agent_name(agent) - latency_ms = self._tock("run") - span_id = self._run_span_id or self._new_span_id() - - payload = self._payload(agent_name=name) - if latency_ms is not None: - payload["duration_ns"] = int(latency_ms * 1_000_000) - - result = getattr(event, "result", None) - if result is not None: - stop_reason = getattr(result, "stop_reason", None) - if stop_reason: - payload["stop_reason"] = str(stop_reason) - - message = getattr(result, "message", None) - self._set_if_capturing(payload, "output", safe_serialize(message)) - - # Emit per-cycle cost.record events matched to model spans. - # accumulated_usage updates AFTER AfterModelCallEvent fires, - # so we read per-cycle tokens here instead. - self._emit_per_cycle_tokens(agent) - - self._fire("agent.output", payload, span_id=span_id, span_name=name) - self._end_trace() - except Exception: - log.warning("layerlens: error in Strands after_invocation", exc_info=True) + agent = event.agent + name = _agent_name(agent) + latency_ms = self._tock("run") + span_id = self._run_span_id or self._new_span_id() + + payload = self._payload(agent_name=name) + if latency_ms is not None: + payload["duration_ns"] = int(latency_ms * 1_000_000) + + result = getattr(event, "result", None) + if result is not None: + stop_reason = getattr(result, "stop_reason", None) + if stop_reason: + payload["stop_reason"] = str(stop_reason) + message = getattr(result, "message", None) + self._set_if_capturing(payload, "output", safe_serialize(message)) + + # Emit per-cycle cost.record events matched to model spans. + # accumulated_usage updates AFTER AfterModelCallEvent fires, + # so we read per-cycle tokens here instead. + self._emit_per_cycle_tokens(agent) + + self._fire("agent.output", payload, span_id=span_id, span_name=name) + self._end_trace() + + @resilient_callback(callback_name="_on_before_model") def _on_before_model(self, event: Any) -> None: - try: - agent = event.agent - name = _agent_name(agent) - self._tick(f"model:{name}") - except Exception: - log.warning("layerlens: error in Strands before_model", exc_info=True) + agent = event.agent + name = _agent_name(agent) + self._tick(f"model:{name}") + @resilient_callback(callback_name="_on_after_model") def _on_after_model(self, event: Any) -> None: """Emit model.invoke with timing and error info. @@ -253,95 +247,88 @@ def _on_after_model(self, event: Any) -> None: accumulated_usage AFTER this hook fires. Tokens are emitted per-cycle from _on_after_invocation using the cycle data. """ - try: - agent = event.agent - name = _agent_name(agent) - latency_ms = self._tock(f"model:{name}") + agent = event.agent + name = _agent_name(agent) + latency_ms = self._tock(f"model:{name}") - model_id = _model_id(agent) - payload = self._payload() - if model_id: - payload["model"] = model_id - - if latency_ms is not None: - payload["latency_ms"] = latency_ms - - exception = getattr(event, "exception", None) - if exception is not None: - payload["error"] = str(exception) - payload["error_type"] = type(exception).__name__ - - stop_response = getattr(event, "stop_response", None) - if stop_response is not None: - stop_reason = getattr(stop_response, "stop_reason", None) - if stop_reason: - payload["stop_reason"] = str(stop_reason) - - parent = self._run_span_id - span_id = self._new_span_id() - self._fire("model.invoke", payload, span_id=span_id, parent_span_id=parent) - with self._lock: - self._model_span_ids.append(span_id) - except Exception: - log.warning("layerlens: error in Strands after_model", exc_info=True) + model_id = _model_id(agent) + payload = self._payload() + if model_id: + payload["model"] = model_id + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + exception = getattr(event, "exception", None) + if exception is not None: + payload["error"] = str(exception) + payload["error_type"] = type(exception).__name__ + + stop_response = getattr(event, "stop_response", None) + if stop_response is not None: + stop_reason = getattr(stop_response, "stop_reason", None) + if stop_reason: + payload["stop_reason"] = str(stop_reason) + + parent = self._run_span_id + span_id = self._new_span_id() + self._fire("model.invoke", payload, span_id=span_id, parent_span_id=parent) + with self._lock: + self._model_span_ids.append(span_id) + + @resilient_callback(callback_name="_on_before_tool") def _on_before_tool(self, event: Any) -> None: - try: - tool_use = event.tool_use - tool_name = ( - tool_use.get("name", "unknown") if isinstance(tool_use, dict) else getattr(tool_use, "name", "unknown") - ) - tool_id = ( - tool_use.get("toolUseId", tool_name) - if isinstance(tool_use, dict) - else getattr(tool_use, "toolUseId", tool_name) - ) - self._tick(f"tool:{tool_id}") - except Exception: - log.warning("layerlens: error in Strands before_tool", exc_info=True) + tool_use = event.tool_use + tool_name = ( + tool_use.get("name", "unknown") if isinstance(tool_use, dict) else getattr(tool_use, "name", "unknown") + ) + tool_id = ( + tool_use.get("toolUseId", tool_name) + if isinstance(tool_use, dict) + else getattr(tool_use, "toolUseId", tool_name) + ) + self._tick(f"tool:{tool_id}") + @resilient_callback(callback_name="_on_after_tool") def _on_after_tool(self, event: Any) -> None: - try: - tool_use = event.tool_use - tool_name = ( - tool_use.get("name", "unknown") if isinstance(tool_use, dict) else getattr(tool_use, "name", "unknown") - ) - tool_id = ( - tool_use.get("toolUseId", tool_name) - if isinstance(tool_use, dict) - else getattr(tool_use, "toolUseId", tool_name) - ) - tool_input = tool_use.get("input", None) if isinstance(tool_use, dict) else getattr(tool_use, "input", None) - latency_ms = self._tock(f"tool:{tool_id}") - - parent = self._run_span_id - span_id = self._new_span_id() - - call_payload = self._payload(tool_name=tool_name) - self._set_if_capturing(call_payload, "input", safe_serialize(tool_input)) - if latency_ms is not None: - call_payload["latency_ms"] = latency_ms - self._fire("tool.call", call_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") - - result = getattr(event, "result", None) - result_payload = self._payload(tool_name=tool_name) - if result is not None: - status = result.get("status", None) if isinstance(result, dict) else getattr(result, "status", None) - if status: - result_payload["status"] = str(status) - content = result.get("content", None) if isinstance(result, dict) else getattr(result, "content", None) - self._set_if_capturing(result_payload, "output", safe_serialize(content)) - - exception = getattr(event, "exception", None) - if exception is not None: - result_payload["error"] = str(exception) - result_payload["error_type"] = type(exception).__name__ - - self._fire( - "tool.result", result_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}" - ) - except Exception: - log.warning("layerlens: error in Strands after_tool", exc_info=True) + tool_use = event.tool_use + tool_name = ( + tool_use.get("name", "unknown") if isinstance(tool_use, dict) else getattr(tool_use, "name", "unknown") + ) + tool_id = ( + tool_use.get("toolUseId", tool_name) + if isinstance(tool_use, dict) + else getattr(tool_use, "toolUseId", tool_name) + ) + tool_input = tool_use.get("input", None) if isinstance(tool_use, dict) else getattr(tool_use, "input", None) + latency_ms = self._tock(f"tool:{tool_id}") + + parent = self._run_span_id + span_id = self._new_span_id() + + call_payload = self._payload(tool_name=tool_name) + self._set_if_capturing(call_payload, "input", safe_serialize(tool_input)) + if latency_ms is not None: + call_payload["latency_ms"] = latency_ms + self._fire("tool.call", call_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") + + result = getattr(event, "result", None) + result_payload = self._payload(tool_name=tool_name) + if result is not None: + status = result.get("status", None) if isinstance(result, dict) else getattr(result, "status", None) + if status: + result_payload["status"] = str(status) + content = result.get("content", None) if isinstance(result, dict) else getattr(result, "content", None) + self._set_if_capturing(result_payload, "output", safe_serialize(content)) + + exception = getattr(event, "exception", None) + if exception is not None: + result_payload["error"] = str(exception) + result_payload["error_type"] = type(exception).__name__ + + self._fire( + "tool.result", result_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}" + ) # ------------------------------------------------------------------ # Helpers diff --git a/tests/instrument/adapters/_base/__init__.py b/tests/instrument/adapters/_base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/instrument/adapters/_base/test_per_adapter_resilience.py b/tests/instrument/adapters/_base/test_per_adapter_resilience.py new file mode 100644 index 00000000..6ae9a57b --- /dev/null +++ b/tests/instrument/adapters/_base/test_per_adapter_resilience.py @@ -0,0 +1,537 @@ +"""Per-adapter resilience smoke tests for all 10 lighter framework adapters. + +These tests instantiate each adapter, force a callback to raise (by +sabotaging an inner helper), and assert: + +1. The exception does NOT propagate (framework would crash otherwise). +2. The resilience tracker recorded the failure. +3. After enough failures, ``adapter_info().metadata['resilience_status']`` + is ``"degraded"``. + +This is the per-adapter complement to ``test_resilience.py``, which +covers the decorator + tracker mechanics in isolation. + +Adapters covered (10 lighter): + agno, llamaindex, google_adk, strands, pydantic_ai, + smolagents, bedrock_agents, openai_agents, haystack, langfuse. + +Each adapter is exercised against a CALLBACK that exists on the +instance unconditionally (no need to construct framework-specific +fixture objects). The actual callback bodies use ``self._payload(...)``, +``self._fire(...)`` or similar internal helpers that we monkey-patch +to force a failure deterministically. +""" + +from __future__ import annotations + +from typing import Any, Dict +from unittest.mock import Mock + +import pytest + +from layerlens.instrument._context import ( + _current_run, + _current_span_id, + _current_collector, +) +from layerlens.instrument.adapters._base import DEFAULT_FAILURE_THRESHOLD + + +@pytest.fixture(autouse=True) +def _isolate_context_vars(): + """Ensure ContextVar state is clean before AND after every test. + + Several callbacks under test (pydantic_ai._on_before_run, + bedrock_agents._before_invoke) intentionally call _begin_run() — + when the test then forces those callbacks to fail, the ContextVar + tokens pushed by _begin_run are NOT popped (because the failure + happens after the push). Without per-test cleanup those leaked + tokens corrupt subsequent tests in the same process (notably + ``tests/instrument/test_trace_context.py``). + """ + # Snapshot current state (likely None) and force a clean baseline. + run_token = _current_run.set(None) + col_token = _current_collector.set(None) + span_token = _current_span_id.set(None) + try: + yield + finally: + # Hard reset — tests in this module are not expected to leave + # any persistent run/collector/span state. + for var, token in ( + (_current_run, run_token), + (_current_collector, col_token), + (_current_span_id, span_token), + ): + try: + var.reset(token) + except (ValueError, LookupError): + var.set(None) + + +class _Boom(Exception): + """Sentinel exception type used to verify the right error was caught.""" + + +def _force_payload_failure(adapter: Any) -> None: + """Sabotage ``adapter._payload`` so any callback that touches it raises.""" + + def _raise(*args: Any, **kwargs: Any) -> Dict[str, Any]: + raise _Boom("simulated framework callback failure") + + adapter._payload = _raise # type: ignore[method-assign] + + +def _force_fire_failure(adapter: Any) -> None: + """Sabotage ``adapter._fire`` for adapters whose callbacks call _fire directly.""" + + def _raise(*args: Any, **kwargs: Any) -> None: + raise _Boom("simulated _fire failure") + + adapter._fire = _raise # type: ignore[method-assign] + + +def _force_emit_failure(adapter: Any) -> None: + """Sabotage ``adapter._emit`` for adapters whose callbacks call _emit directly.""" + + def _raise(*args: Any, **kwargs: Any) -> None: + raise _Boom("simulated _emit failure") + + adapter._emit = _raise # type: ignore[method-assign] + + +# --------------------------------------------------------------------------- +# agno +# --------------------------------------------------------------------------- + + +class TestAgnoResilience: + def test_on_run_start_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + adapter = AgnoAdapter(Mock()) + _force_payload_failure(adapter) + # Must not raise. + result = adapter._on_run_start(Mock(), "input") + assert result is None + assert adapter._resilience.total_failures == 1 + + def test_on_run_end_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + adapter = AgnoAdapter(Mock()) + _force_payload_failure(adapter) + result = adapter._on_run_end(Mock(), Mock(), None) + assert result is None + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# llamaindex +# --------------------------------------------------------------------------- + + +class TestLlamaIndexResilience: + def test_on_query_start_failure_caught(self) -> None: + pytest.importorskip("llama_index.core") + from layerlens.instrument.adapters.frameworks.llamaindex import LlamaIndexAdapter + + adapter = LlamaIndexAdapter(Mock()) + _force_payload_failure(adapter) + event = Mock() + event.span_id = "span-1" + result = adapter._on_query_start(event) + assert result is None + assert adapter._resilience.total_failures == 1 + + def test_handle_event_unknown_type_no_op(self) -> None: + pytest.importorskip("llama_index.core") + from layerlens.instrument.adapters.frameworks.llamaindex import LlamaIndexAdapter + + adapter = LlamaIndexAdapter(Mock()) + # Unknown event class — handler lookup returns None, no exception. + adapter._handle_event(object()) + # No failure recorded — unknown types are a no-op, not an error. + assert adapter._resilience.total_failures == 0 + + def test_on_span_enter_failure_caught(self) -> None: + pytest.importorskip("llama_index.core") + from layerlens.instrument.adapters.frameworks.llamaindex import LlamaIndexAdapter + + adapter = LlamaIndexAdapter(Mock()) + + # Sabotage open_spans dict access by replacing _open_spans with + # an object that raises on __setitem__. + class _Bad: + def __setitem__(self, key: Any, value: Any) -> None: + raise _Boom("dict broken") + + def get(self, key: Any, default: Any = None) -> Any: + return default + + def __contains__(self, key: Any) -> bool: + return False + + adapter._open_spans = _Bad() # type: ignore[assignment] + result = adapter._on_span_enter("id-1", None) + # Default for span lifecycle is None — LlamaIndex tolerates it. + assert result is None + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# google_adk +# --------------------------------------------------------------------------- + + +class TestGoogleAdkResilience: + def test_on_before_run_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.google_adk import GoogleADKAdapter + + adapter = GoogleADKAdapter(Mock()) + _force_payload_failure(adapter) + adapter._on_before_run(Mock()) + assert adapter._resilience.total_failures == 1 + + def test_on_after_run_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.google_adk import GoogleADKAdapter + + adapter = GoogleADKAdapter(Mock()) + _force_payload_failure(adapter) + adapter._on_after_run(Mock()) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# strands +# --------------------------------------------------------------------------- + + +class TestStrandsResilience: + def test_on_before_invocation_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.strands import StrandsAdapter + + adapter = StrandsAdapter(Mock()) + _force_payload_failure(adapter) + # Build a minimal event shim — the wrapped callback will raise + # when it tries to call _payload, which our sabotage replaces. + event = Mock() + event.agent = Mock() + adapter._on_before_invocation(event) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# pydantic_ai +# --------------------------------------------------------------------------- + + +class TestPydanticAiResilience: + def test_on_before_run_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter + + adapter = PydanticAIAdapter(Mock()) + _force_payload_failure(adapter) + # _on_before_run calls _begin_run() then _payload — the latter + # raises and must be caught. + adapter._on_before_run(Mock()) + assert adapter._resilience.total_failures == 1 + + def test_on_after_model_request_passthrough_returns_response(self) -> None: + # Critical: when _on_after_model_request raises, the wrapper + # MUST return the original response object (passthrough_arg= + # "response") otherwise the agent's LLM response becomes None + # and the agent crashes downstream. + from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter + + adapter = PydanticAIAdapter(Mock()) + _force_emit_failure(adapter) + sentinel_response = Mock(name="response_object") + result = adapter._on_after_model_request( + Mock(), + request_context=Mock(), + response=sentinel_response, + ) + assert result is sentinel_response + assert adapter._resilience.total_failures == 1 + + def test_on_before_tool_execute_passthrough_returns_args(self) -> None: + from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter + + adapter = PydanticAIAdapter(Mock()) + + # Sabotage _start_timer (used inside _on_before_tool_execute). + def _raise(*a: Any, **kw: Any) -> None: + raise _Boom("timer broken") + + adapter._start_timer = _raise # type: ignore[method-assign] + sentinel_args = ("a", "b", "c") + result = adapter._on_before_tool_execute( + Mock(), + call=Mock(), + tool_def=Mock(), + args=sentinel_args, + ) + assert result == sentinel_args + assert adapter._resilience.total_failures == 1 + + def test_run_error_re_raises_framework_error(self) -> None: + # The error-callback path MUST always re-raise the framework's + # original error — even when our telemetry helper raises. The + # framework's contract requires the error to propagate. + from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter + + adapter = PydanticAIAdapter(Mock()) + _force_emit_failure(adapter) # telemetry will fail + + framework_error = ValueError("framework's own error") + with pytest.raises(ValueError, match="framework's own error"): + adapter._on_run_error(Mock(), error=framework_error) + + # Telemetry helper failure was caught + recorded; re-raise still happened. + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# smolagents +# --------------------------------------------------------------------------- + + +class TestSmolAgentsResilience: + def test_on_action_step_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.smolagents import SmolAgentsAdapter + + adapter = SmolAgentsAdapter(Mock()) + + # Sabotage _handle_action_step itself. + def _raise(*a: Any, **kw: Any) -> None: + raise _Boom("handler broken") + + adapter._handle_action_step = _raise # type: ignore[method-assign] + adapter._on_action_step(Mock(), Mock()) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# bedrock_agents +# --------------------------------------------------------------------------- + + +class TestBedrockAgentsResilience: + def test_before_invoke_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.bedrock_agents import BedrockAgentsAdapter + + adapter = BedrockAgentsAdapter(Mock()) + adapter._connected = True # bypass the early ``not connected`` return + _force_payload_failure(adapter) + # boto3 invokes hooks with kwargs only. + adapter._before_invoke(params={"agentId": "id-1", "sessionId": "s-1"}) + assert adapter._resilience.total_failures == 1 + + def test_after_invoke_finally_runs_even_on_failure(self) -> None: + # The outer _after_invoke wraps the inner body in try/finally so + # _end_run() always fires — critical for releasing ContextVars. + from layerlens.instrument.adapters.frameworks.bedrock_agents import BedrockAgentsAdapter + + adapter = BedrockAgentsAdapter(Mock()) + adapter._connected = True + # Set up a run scope so _end_run has something to clean up. + adapter._begin_run() + _force_payload_failure(adapter) + + end_run_called = [] + original_end_run = adapter._end_run + + def _spy_end_run() -> None: + end_run_called.append(True) + original_end_run() + + adapter._end_run = _spy_end_run # type: ignore[method-assign] + adapter._after_invoke(parsed={"sessionId": "s-1"}) + # _end_run fired despite the body raising. + assert end_run_called == [True] + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# openai_agents +# --------------------------------------------------------------------------- + + +class TestOpenAiAgentsResilience: + def test_on_trace_start_failure_caught(self) -> None: + # If the SDK isn't installed, we still want to test the resilience + # wiring — but the class can't be instantiated because the parent + # TracingProcessor isn't available. Skip in that case. + pytest.importorskip("agents") + from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter + + adapter = OpenAIAgentsAdapter(Mock()) + + # Sabotage RunState construction by patching the lock to raise. + def _raise_on_acquire(*a: Any, **kw: Any) -> Any: + raise _Boom("lock broken") + + adapter._lock = Mock() + adapter._lock.__enter__ = _raise_on_acquire + adapter._lock.__exit__ = lambda *a: None + + trace = Mock() + trace.trace_id = "t-1" + adapter.on_trace_start(trace) + assert adapter._resilience.total_failures == 1 + + def test_on_trace_end_failure_caught(self) -> None: + pytest.importorskip("agents") + from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter + + adapter = OpenAIAgentsAdapter(Mock()) + + # Sabotage the trace_runs dict. + class _Bad: + def pop(self, *a: Any, **kw: Any) -> Any: + raise _Boom("dict broken") + + adapter._trace_runs = _Bad() # type: ignore[assignment] + trace = Mock() + trace.trace_id = "t-1" + adapter.on_trace_end(trace) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# haystack +# --------------------------------------------------------------------------- + + +class TestHaystackResilience: + def test_on_span_end_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.haystack import ( + HaystackAdapter, + _LayerLensSpan, + ) + + adapter = HaystackAdapter(Mock()) + # Sabotage the _on_pipeline_end branch. + def _raise(*a: Any, **kw: Any) -> None: + raise _Boom("pipeline broken") + + adapter._on_pipeline_end = _raise # type: ignore[method-assign] + + span = _LayerLensSpan( + adapter, + "haystack.pipeline.run", + "span-1", + None, + {}, + is_pipeline=True, + ) + adapter._on_span_end(span) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# langfuse +# --------------------------------------------------------------------------- + + +class TestLangfuseResilience: + def test_import_observation_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.langfuse import LangfuseAdapter + + adapter = LangfuseAdapter(Mock()) + + # Force the inner branch to raise — _import_generation is called + # for type=GENERATION, _import_span for SPAN; sabotage the SPAN + # path with a malformed obs. + collector = Mock() + bad_obs = {"type": "SPAN", "id": "obs-1"} + + # Sabotage _import_span specifically. + def _raise(*a: Any, **kw: Any) -> None: + raise _Boom("span import broken") + + adapter._import_span = _raise # type: ignore[method-assign] + adapter._import_observation(collector, bad_obs, "root-span") + assert adapter._resilience.total_failures == 1 + + def test_import_score_failure_caught(self) -> None: + from layerlens.instrument.adapters.frameworks.langfuse import LangfuseAdapter + + adapter = LangfuseAdapter(Mock()) + collector = Mock() + # collector.emit raises — our score importer must catch. + collector.emit.side_effect = _Boom("collector broken") + adapter._import_score( + collector, + {"sessionId": "s-1"}, + "trace-1", + "root-span", + {"name": "quality", "value": 0.9}, + ) + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# Health degradation across all 10 adapters +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "module_path, class_name", + [ + ("layerlens.instrument.adapters.frameworks.agno", "AgnoAdapter"), + ("layerlens.instrument.adapters.frameworks.smolagents", "SmolAgentsAdapter"), + ("layerlens.instrument.adapters.frameworks.google_adk", "GoogleADKAdapter"), + ("layerlens.instrument.adapters.frameworks.strands", "StrandsAdapter"), + ("layerlens.instrument.adapters.frameworks.pydantic_ai", "PydanticAIAdapter"), + ("layerlens.instrument.adapters.frameworks.bedrock_agents", "BedrockAgentsAdapter"), + ("layerlens.instrument.adapters.frameworks.haystack", "HaystackAdapter"), + ("layerlens.instrument.adapters.frameworks.langfuse", "LangfuseAdapter"), + ], +) +def test_adapter_health_degrades_on_repeated_failures(module_path: str, class_name: str) -> None: + """Every lighter adapter exposes resilience health via adapter_info().metadata.""" + import importlib + + module = importlib.import_module(module_path) + adapter_cls = getattr(module, class_name) + adapter = adapter_cls(Mock()) + + # Hit the threshold by recording failures directly on the tracker + # (faster than driving each adapter's specific callback path; this + # test is purely about the metadata surface). + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter._resilience.record_failure("synthetic", _Boom("threshold test")) + + info = adapter.adapter_info() + assert info.metadata["resilience_status"] == "degraded" + assert info.metadata["resilience_failures_total"] == DEFAULT_FAILURE_THRESHOLD + + +# --------------------------------------------------------------------------- +# llamaindex / openai_agents handled separately because they need +# llama_index_core / agents installed at test time. Use importorskip. +# --------------------------------------------------------------------------- + + +def test_llamaindex_health_degrades_on_repeated_failures() -> None: + pytest.importorskip("llama_index.core") + from layerlens.instrument.adapters.frameworks.llamaindex import LlamaIndexAdapter + + adapter = LlamaIndexAdapter(Mock()) + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter._resilience.record_failure("synthetic", _Boom("threshold")) + info = adapter.adapter_info() + assert info.metadata["resilience_status"] == "degraded" + + +def test_openai_agents_health_degrades_on_repeated_failures() -> None: + pytest.importorskip("agents") + from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter + + adapter = OpenAIAgentsAdapter(Mock()) + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter._resilience.record_failure("synthetic", _Boom("threshold")) + info = adapter.adapter_info() + assert info.metadata["resilience_status"] == "degraded" diff --git a/tests/instrument/adapters/_base/test_resilience.py b/tests/instrument/adapters/_base/test_resilience.py new file mode 100644 index 00000000..5c1f730c --- /dev/null +++ b/tests/instrument/adapters/_base/test_resilience.py @@ -0,0 +1,500 @@ +"""Tests for the per-callback resilience wrapper. + +Covers ``ResilienceTracker``, ``resilient_callback``, ``get_default_for``, +``HealthStatus``, and the ``adapter_info().metadata`` integration on +``FrameworkAdapter`` subclasses. + +Every test asserts a behaviour that prevents observability code from +breaking the framework's own execution path. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any, Dict +from unittest.mock import Mock + +import pytest + +from layerlens.instrument.adapters._base import ( + DEFAULT_FAILURE_THRESHOLD, + AdapterInfo, + BaseAdapter, + HealthStatus, + ResilienceTracker, + get_default_for, + resilient_callback, +) +from layerlens.instrument.adapters.frameworks._base_framework import FrameworkAdapter + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +class _Boom(Exception): + """Sentinel error type so tests can assert the right exception was caught.""" + + +class _DummyAdapter: + """Minimal adapter shape — provides ``name`` and ``_resilience`` only.""" + + name = "dummy" + + def __init__(self) -> None: + self._resilience = ResilienceTracker(self.name) + + @resilient_callback(callback_name="my_callback", default="DEFAULT") + def my_callback(self, value: Any) -> Any: + if value == "raise": + raise _Boom("dummy failure") + return f"ok:{value}" + + @resilient_callback(callback_name="passthrough_cb", passthrough_arg="value") + def passthrough_cb(self, value: Any) -> Any: + if value == "raise": + raise _Boom("passthrough failure") + return f"transformed:{value}" + + @resilient_callback(callback_name="kw_passthrough", passthrough_arg="payload") + def kw_passthrough(self, *, payload: Any) -> Any: + if payload == "raise": + raise _Boom("kw passthrough failure") + return {"wrapped": payload} + + +class _MinimalFramework(FrameworkAdapter): + """Real FrameworkAdapter subclass for integration tests.""" + + name = "test-framework" + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + return None + + @resilient_callback(callback_name="emit_thing") + def emit_thing(self, value: int) -> None: + if value < 0: + raise _Boom(f"negative value: {value}") + + +# --------------------------------------------------------------------------- +# get_default_for +# --------------------------------------------------------------------------- + + +class TestGetDefaultFor: + def test_known_callback_returns_none(self) -> None: + # All registered callbacks default to None — the framework default + # for void-callback APIs (Strands hooks, Google ADK plugins, boto3). + assert get_default_for("on_trace_start") is None + assert get_default_for("_before_invoke") is None + assert get_default_for("after_run_callback") is None + + def test_unknown_callback_returns_none(self) -> None: + # Unknown callback names also return None — the safe default. If a + # callback needs a non-None default, the adapter must pass it + # explicitly via @resilient_callback(default=...). + assert get_default_for("does_not_exist") is None + assert get_default_for("") is None + + +# --------------------------------------------------------------------------- +# ResilienceTracker +# --------------------------------------------------------------------------- + + +class TestResilienceTracker: + def test_starts_healthy_with_zero_failures(self) -> None: + tracker = ResilienceTracker("test") + assert tracker.total_failures == 0 + assert tracker.health_status() == HealthStatus.HEALTHY + + def test_threshold_validation(self) -> None: + with pytest.raises(ValueError): + ResilienceTracker("test", threshold=0) + with pytest.raises(ValueError): + ResilienceTracker("test", threshold=-1) + + def test_record_failure_increments_counter(self) -> None: + tracker = ResilienceTracker("test") + tracker.record_failure("cb1", _Boom("first")) + assert tracker.total_failures == 1 + tracker.record_failure("cb1", _Boom("second")) + tracker.record_failure("cb2", _Boom("third")) + assert tracker.total_failures == 3 + + def test_health_degrades_after_threshold(self) -> None: + tracker = ResilienceTracker("test", threshold=3) + tracker.record_failure("cb", _Boom("a")) + tracker.record_failure("cb", _Boom("b")) + assert tracker.health_status() == HealthStatus.HEALTHY + tracker.record_failure("cb", _Boom("c")) + assert tracker.health_status() == HealthStatus.DEGRADED + + def test_metadata_snapshot(self) -> None: + tracker = ResilienceTracker("test", threshold=2) + tracker.record_failure("cb1", _Boom("oops")) + tracker.record_failure("cb2", ValueError("bad value")) + snap = tracker.as_metadata() + assert snap["resilience_status"] == HealthStatus.DEGRADED.value + assert snap["resilience_failures_total"] == 2 + assert snap["resilience_failure_threshold"] == 2 + # Per-callback breakdown carries the top failure counts. + assert snap["resilience_failures_by_callback"] == {"cb1": 1, "cb2": 1} + # Last error preserved (truncated) for triage. + assert snap["resilience_last_callback"] == "cb2" + assert "ValueError" in snap["resilience_last_error"] + + def test_reset_clears_state(self) -> None: + tracker = ResilienceTracker("test") + tracker.record_failure("cb", _Boom("x")) + assert tracker.total_failures == 1 + tracker.reset() + assert tracker.total_failures == 0 + assert tracker.health_status() == HealthStatus.HEALTHY + snap = tracker.as_metadata() + assert snap["resilience_failures_total"] == 0 + assert "resilience_last_error" not in snap + + def test_thread_safety(self) -> None: + # Many threads recording failures concurrently must not lose any + # increments — observability code commonly fires from worker + # threads (CrewAI, AutoGen group chat, Bedrock boto3 hooks). + tracker = ResilienceTracker("test", threshold=DEFAULT_FAILURE_THRESHOLD) + per_thread_count = 100 + thread_count = 8 + + def _worker() -> None: + for _ in range(per_thread_count): + tracker.record_failure("cb", _Boom("concurrent")) + + threads = [threading.Thread(target=_worker) for _ in range(thread_count)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert tracker.total_failures == per_thread_count * thread_count + + +# --------------------------------------------------------------------------- +# resilient_callback decorator +# --------------------------------------------------------------------------- + + +class TestResilientCallbackDecorator: + def test_returns_normal_value_on_success(self) -> None: + adapter = _DummyAdapter() + assert adapter.my_callback("hello") == "ok:hello" + assert adapter._resilience.total_failures == 0 + + def test_returns_default_on_exception(self) -> None: + adapter = _DummyAdapter() + # Without the wrapper, the call would raise — instead it returns + # the framework's expected default and the framework continues. + result = adapter.my_callback("raise") + assert result == "DEFAULT" + + def test_failure_counter_incremented(self) -> None: + adapter = _DummyAdapter() + adapter.my_callback("raise") + adapter.my_callback("raise") + adapter.my_callback("ok-1") # this one succeeds + adapter.my_callback("raise") + assert adapter._resilience.total_failures == 3 + + def test_exception_is_logged_with_context(self, caplog: pytest.LogCaptureFixture) -> None: + adapter = _DummyAdapter() + with caplog.at_level(logging.WARNING): + adapter.my_callback("raise") + + # Adapter name + callback name + traceback all surfaced. + assert any( + "dummy" in rec.message and "my_callback" in rec.message and "_Boom" in rec.message + for rec in caplog.records + ) + + def test_exception_does_not_propagate(self) -> None: + adapter = _DummyAdapter() + # The whole point: the framework calling this method MUST NOT see + # _Boom — that would crash the user's agent. + try: + adapter.my_callback("raise") + except _Boom: + pytest.fail("resilient_callback let the exception escape") + + def test_passthrough_arg_returns_positional_value(self) -> None: + adapter = _DummyAdapter() + # On failure the wrapper returns the passthrough arg's value + # rather than the default — critical for mutating hooks + # (Pydantic-AI ``after_model_request`` returns ``response``). + assert adapter.passthrough_cb("raise") == "raise" + # On success, the original return value flows through. + assert adapter.passthrough_cb("ok") == "transformed:ok" + + def test_passthrough_arg_returns_keyword_value(self) -> None: + adapter = _DummyAdapter() + assert adapter.kw_passthrough(payload="raise") == "raise" + assert adapter.kw_passthrough(payload="data") == {"wrapped": "data"} + + def test_health_degrades_after_repeated_failures(self) -> None: + adapter = _DummyAdapter() + # Default threshold is 5; after 5 consecutive failures the + # adapter reports DEGRADED so monitoring can alert. + assert adapter._resilience.health_status() == HealthStatus.HEALTHY + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter.my_callback("raise") + assert adapter._resilience.health_status() == HealthStatus.DEGRADED + + def test_keyboard_interrupt_propagates(self) -> None: + # We must NEVER swallow KeyboardInterrupt / SystemExit / + # GeneratorExit — those are control-flow signals, not bugs. + class _CtrlCAdapter: + name = "ctrlc" + + def __init__(self) -> None: + self._resilience = ResilienceTracker(self.name) + + @resilient_callback(callback_name="cb") + def cb(self) -> None: + raise KeyboardInterrupt("user pressed Ctrl-C") + + adapter = _CtrlCAdapter() + with pytest.raises(KeyboardInterrupt): + adapter.cb() + + def test_works_without_resilience_tracker_attribute( + self, caplog: pytest.LogCaptureFixture + ) -> None: + # If an adapter forgets to set up _resilience, the wrapper still + # logs and returns the default — never crashes the framework. + class _NoTracker: + name = "no_tracker" + + @resilient_callback(callback_name="cb", default="OK") + def cb(self) -> str: + raise _Boom("no tracker") + + adapter = _NoTracker() + with caplog.at_level(logging.WARNING): + assert adapter.cb() == "OK" + assert any("_Boom" in rec.message for rec in caplog.records) + + def test_logger_uses_module_of_decorated_function( + self, caplog: pytest.LogCaptureFixture + ) -> None: + # Failures are logged via the wrapped function's module logger so + # users can mute one adapter's resilience warnings without + # silencing all of them. + adapter = _DummyAdapter() + with caplog.at_level(logging.WARNING, logger=__name__): + adapter.my_callback("raise") + # Our test module's logger captured the warning. + assert any(rec.name == __name__ for rec in caplog.records) + + +# --------------------------------------------------------------------------- +# Integration with FrameworkAdapter +# --------------------------------------------------------------------------- + + +class TestFrameworkAdapterIntegration: + def test_framework_adapter_owns_resilience_tracker(self) -> None: + adapter = _MinimalFramework(Mock()) + assert isinstance(adapter._resilience, ResilienceTracker) + assert adapter._resilience.total_failures == 0 + + def test_adapter_info_surfaces_resilience_metadata(self) -> None: + adapter = _MinimalFramework(Mock()) + info: AdapterInfo = adapter.adapter_info() + meta = info.metadata + assert meta["resilience_status"] == "healthy" + assert meta["resilience_failures_total"] == 0 + assert meta["resilience_failure_threshold"] == DEFAULT_FAILURE_THRESHOLD + + def test_adapter_info_reports_degraded_after_failures(self) -> None: + adapter = _MinimalFramework(Mock()) + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter.emit_thing(-1) # raises inside the wrapped method + info = adapter.adapter_info() + assert info.metadata["resilience_status"] == "degraded" + assert info.metadata["resilience_failures_total"] == DEFAULT_FAILURE_THRESHOLD + + def test_disconnect_resets_resilience(self) -> None: + adapter = _MinimalFramework(Mock()) + adapter.connect() + adapter.emit_thing(-1) + adapter.emit_thing(-2) + assert adapter._resilience.total_failures == 2 + adapter.disconnect() + assert adapter._resilience.total_failures == 0 + + def test_callback_failure_does_not_break_framework(self) -> None: + adapter = _MinimalFramework(Mock()) + adapter.connect() + # Simulating "framework fires our callback" — the callback throws + # but the framework's call-site sees no exception, just None. + result = adapter.emit_thing(-99) + assert result is None + assert adapter._resilience.total_failures == 1 + + +# --------------------------------------------------------------------------- +# Public surface re-exports +# --------------------------------------------------------------------------- + + +class TestPackageExports: + def test_base_package_re_exports_resilience_helpers(self) -> None: + from layerlens.instrument.adapters._base import ( + DEFAULT_FAILURE_THRESHOLD as T1, + AdapterInfo as A1, + BaseAdapter as B1, + HealthStatus as H1, + ResilienceTracker as R1, + get_default_for as G1, + resilient_callback as RC1, + ) + + # Sanity — every public symbol resolves and is the right kind. + assert A1 is AdapterInfo + assert B1 is BaseAdapter + assert R1 is ResilienceTracker + assert RC1 is resilient_callback + assert H1 is HealthStatus + assert T1 == DEFAULT_FAILURE_THRESHOLD + assert G1 is get_default_for + + +# --------------------------------------------------------------------------- +# Decorator preserves function metadata +# --------------------------------------------------------------------------- + + +class TestDecoratorMetadata: + def test_wrapped_function_keeps_name_and_docstring(self) -> None: + class _A: + name = "x" + _resilience = ResilienceTracker("x") + + @resilient_callback(callback_name="cb") + def cb(self) -> None: + """My docstring.""" + pass + + # functools.wraps preserves __name__ and __doc__ — important for + # frameworks that introspect handlers by name (boto3 event system + # uses handler identity for unregister()). + assert _A.cb.__name__ == "cb" + assert _A.cb.__doc__ == "My docstring." + + +# --------------------------------------------------------------------------- +# End-to-end: verifying the per-adapter failure scenario +# --------------------------------------------------------------------------- + + +class TestPerAdapterCallbackException: + """Simulate a callback exception on each lighter adapter and assert + the framework continues unaffected. + + Each test instantiates the adapter, monkey-patches one of its + callback methods to raise, then invokes the callback and asserts: + + 1. No exception escaped (framework would crash otherwise). + 2. The resilience tracker incremented its failure counter. + 3. Repeated failures cross the threshold and degrade adapter health. + """ + + @pytest.mark.parametrize( + "module_path, class_name, callback_name, callback_args", + [ + ( + "layerlens.instrument.adapters.frameworks.agno", + "AgnoAdapter", + "_on_run_start", + (Mock(), "input"), + ), + ( + "layerlens.instrument.adapters.frameworks.agno", + "AgnoAdapter", + "_on_run_end", + (Mock(), Mock(), None), + ), + ( + "layerlens.instrument.adapters.frameworks.smolagents", + "SmolAgentsAdapter", + "_on_run_start", + (Mock(), "task"), + ), + ( + "layerlens.instrument.adapters.frameworks.smolagents", + "SmolAgentsAdapter", + "_on_run_end", + (Mock(), Mock(), None), + ), + ( + "layerlens.instrument.adapters.frameworks.smolagents", + "SmolAgentsAdapter", + "_on_run_error", + (Mock(), _Boom("framework error")), + ), + ( + "layerlens.instrument.adapters.frameworks.smolagents", + "SmolAgentsAdapter", + "_on_action_step", + (Mock(), Mock()), + ), + ], + ) + def test_callback_exception_caught_and_counted( + self, + module_path: str, + class_name: str, + callback_name: str, + callback_args: tuple, + ) -> None: + import importlib + + module = importlib.import_module(module_path) + adapter_cls = getattr(module, class_name) + adapter = adapter_cls(Mock()) + + # Force the underlying body to raise by sabotaging an inner + # helper the callback always calls. The simplest way is to patch + # ``adapter._payload`` to raise — every callback uses it. + original_payload = adapter._payload + + def _raise_on_payload(*args: Any, **kwargs: Any) -> Dict[str, Any]: + raise _Boom("simulated callback failure") + + adapter._payload = _raise_on_payload # type: ignore[method-assign] + + try: + cb = getattr(adapter, callback_name) + # Must not raise — that's the entire resilience contract. + cb(*callback_args) + # Failure recorded against this exact callback name. + assert adapter._resilience.total_failures >= 1 + finally: + adapter._payload = original_payload # type: ignore[method-assign] + + def test_repeated_failures_degrade_adapter(self) -> None: + # Use agno as the proxy — same wiring applies to all 10 lighter + # adapters because they all inherit from FrameworkAdapter and use + # @resilient_callback on their entry points. + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + adapter = AgnoAdapter(Mock()) + + def _raise_on_payload(*args: Any, **kwargs: Any) -> Dict[str, Any]: + raise _Boom("persistent failure") + + adapter._payload = _raise_on_payload # type: ignore[method-assign] + + for _ in range(DEFAULT_FAILURE_THRESHOLD): + adapter._on_run_start(Mock(), "input") + + info = adapter.adapter_info() + assert info.metadata["resilience_status"] == "degraded" diff --git a/tests/instrument/adapters/frameworks/test_langfuse.py b/tests/instrument/adapters/frameworks/test_langfuse.py index 86c50ff2..70aca581 100644 --- a/tests/instrument/adapters/frameworks/test_langfuse.py +++ b/tests/instrument/adapters/frameworks/test_langfuse.py @@ -222,13 +222,21 @@ def test_adapter_info_returns_correct_metadata(self, connected_adapter): assert info.name == "langfuse" assert info.adapter_type == "framework" assert info.connected is True - assert info.metadata == {"host": "https://test.langfuse.com"} + # The Langfuse-specific ``host`` metadata must be present; resilience + # health metadata is added by FrameworkAdapter.adapter_info() to + # every framework adapter — assert presence of both surfaces. + assert info.metadata["host"] == "https://test.langfuse.com" + assert info.metadata["resilience_status"] == "healthy" def test_adapter_info_disconnected(self, mock_client): adapter = LangfuseAdapter(mock_client) info = adapter.adapter_info() assert info.connected is False - assert info.metadata == {} + # Disconnected adapters expose only the resilience health surface + # (no per-adapter metadata since connect() never populated it). + assert info.metadata.get("host") is None + assert info.metadata["resilience_status"] == "healthy" + assert info.metadata["resilience_failures_total"] == 0 # =================================================================== From fec3bed264d11ed25b14599ec464cc5c1d6b161d Mon Sep 17 00:00:00 2001 From: mmercuri Date: Sun, 26 Apr 2026 20:45:45 -0700 Subject: [PATCH 2/2] feat(instrument): State include/exclude key filters for 6 multi-agent adapters (cross-poll #6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements cross-pollination item #6 from `A:/tmp/adapter-cross-pollination-audit.md` §2.12. LangGraph's `LangGraphStateAdapter` (mature) supports include/exclude key filters at the state-snapshot level so customers can scrub sensitive state (api_keys, tokens, PII) WITHOUT modifying their agent code or doing post-hoc redaction. This PR brings the same contract to the lighter multi-agent framework adapters present on this base. ## What's new ### Shared filter module `src/layerlens/instrument/adapters/_base/state_filters.py` — new - `StateFilter` — frozen dataclass with `include_keys`, `exclude_keys`, `mask_keys`, `recursive`. Three operations applied in order: exclude (drop), mask (replace value with `[REDACTED]`), include (allowlist). Matching is case-insensitive substring after alphanumeric-only normalisation, so `X-Api-Key`, `USER_API_KEY`, `customer.email_address` all match without enumerating every variant. - `DEFAULT_PII_EXCLUDE_KEYS` — conservative denylist covering 49 common credential / PII / financial / contact field names. Customers who do nothing still get baseline protection out of the box per CLAUDE.md ("never silently leak customer data"). - `default_state_filter()` — factory installed by every adapter unless the caller passes a custom `state_filter`. - `filter_state(state, filter)` — pure function returning `(filtered_state, filtered_keys)` so adapters can surface the clipped key names as `_filtered_keys` event metadata for audit. - `filter_payload_fields(payload, filter, fields)` — surgical helper that filters only the named dict-shaped fields of a mixed-shape payload (so scalar metadata like `model`, `latency_ms` is preserved). - `StateFilter.permissive()` — opt-out factory for tests / explicit disablement. The active filter snapshot is surfaced under `adapter_info().metadata['state_filter']` so operators can detect accidental disablement. - `StateFilter.with_extra_excludes()` — default + caller's additions. ### FrameworkAdapter integration `src/layerlens/instrument/adapters/frameworks/_base_framework.py` - Constructor accepts optional `state_filter` (defaults to `default_state_filter()`). - `self._state_filter` reachable on every subclass. - New `_filter_payload(payload, *fields)` helper used by adapters immediately before each `_emit(...)` call for any payload that may contain user-controlled state. - New `serialize_state_filter_for_replay()` — replay engine uses this to reconstruct an equivalent filter on the other side, so the captured payload shapes match between original run and replay. - `adapter_info().metadata['state_filter']` surfaces the active config. ### Per-adapter wiring (6 multi-agent adapters) | Adapter | Constructor `state_filter` | Filter applied at emit | |-----------------|----------------------------|--------------------------------| | `agno` | YES | `agent.input/output`, `tool.call/result` | | `openai_agents` | YES | `agent.input` (tools/handoffs/output_type), generation messages, function args + parameters_schema + mcp_data, tool result | | `llamaindex` | YES | LLM messages + output_message, tool args, retrieval input/output, query input/output, agent_step input/output | | `google_adk` | YES | run user_content + agent_tree, agent user_content, tool args, tool result | | `strands` | YES | invocation messages, after-invocation output, tool input, tool result | | `pydantic_ai` | YES | agent input + deps_summary, agent output, tool.call args, tool.result output, streaming output_message | The audit §2.12 enumerates 7 targets including `ms_agent_framework` — that adapter doesn't exist on this branch's base (`feat/instrument-callback-resilience`); it lives only on the parallel `feat/instrument-multitenancy-org-id-propagation` history. It will be wired when the ms_agent_framework adapter is ported to this base or the histories merge. ## Tests (53 new + integration) ### `tests/instrument/adapters/_base/test_state_filters.py` — 53 tests - `TestStateFilterConstruction` — defaults are PII-aware, lowercasing, permissive factory, with_extra_excludes factory. - `TestStateFilterMetadata` — default snapshot shape, allowlist surfaces in metadata. - `TestFilterStateExclude` — default PII keys removed, vendor variants caught (`X-Api-Key`, `USER_API_KEY`, `stripe_customer_email`), permissive opt-out. - `TestFilterStateMask` — keeps key visible, masking runs before recurse so nested PII can't leak through a masked field. - `TestFilterStateInclude` — allowlist semantics, exclude wins over include when both match. - `TestFilterStateRecursive` — nested dicts, lists of dicts, non-recursive flag. - `TestFilterStatePassthrough` — primitives + empty dict pass through. - `TestFilterPayloadFields` — surgical filter (scalars untouched), missing fields skipped, scalar field is no-op, accumulating `_filtered_keys` across multiple passes. - `TestFrameworkAdapterStateFilterDefaults` — default installed, custom override, end-to-end PII drop, replay snapshot, adapter_info. - `TestPerAdapterStateFilterWiring` — parametrized across all 6 adapters: constructor accepts state_filter, default is PII-aware, state_filter surfaces in adapter_info. - `TestEndToEndAgnoFilter` — filter actually runs at the emit boundary (not just sits idle on the adapter). ### Existing test suites unchanged - `tests/instrument/adapters/_base/` — 110 passed, 7 skipped. - `tests/instrument/adapters/frameworks/` — 114 passed (langchain, langgraph, langfuse, agentforce — adapters with deps installed in CI venv), 12 skipped (optional deps), 1 pre-existing Windows clock-resolution flake on test_haystack (documented in PR #117). ## Documentation `docs/adapters/state-filters.md` — explains default behaviour, three filter operations, configuration recipes, recursion, auditability via `_filtered_keys`, replay reproducibility, and the per-adapter wiring matrix. ## Acceptance ``` pytest tests/instrument/adapters/_base/test_state_filters.py -x # 53 passed in 0.10s pytest tests/instrument/adapters/_base/ # 110 passed, 7 skipped in 0.26s pytest tests/instrument/adapters/frameworks/ # adapters with installed deps # 114 passed, 12 skipped, 1 pre-existing flake (test_haystack.test_input_and_output) mypy --strict src/layerlens/instrument/adapters/_base/state_filters.py # Success: no issues found in 1 source file mypy src/layerlens/instrument/adapters/frameworks/{_base_framework,agno,openai_agents,llamaindex,google_adk,strands,pydantic_ai}.py # Success: no issues found in 7 source files ruff check src/layerlens/instrument/adapters/_base/state_filters.py src/layerlens/instrument/adapters/frameworks/{_base_framework,agno,openai_agents,llamaindex,google_adk,strands,pydantic_ai}.py tests/instrument/adapters/_base/test_state_filters.py # All checks passed! ``` --- docs/adapters/state-filters.md | 153 ++++++ .../instrument/adapters/_base/__init__.py | 14 + .../adapters/_base/state_filters.py | 444 ++++++++++++++++++ .../adapters/frameworks/_base_framework.py | 73 ++- .../instrument/adapters/frameworks/agno.py | 15 +- .../adapters/frameworks/google_adk.py | 19 +- .../adapters/frameworks/llamaindex.py | 27 +- .../adapters/frameworks/openai_agents.py | 27 +- .../adapters/frameworks/pydantic_ai.py | 21 +- .../instrument/adapters/frameworks/strands.py | 15 +- .../adapters/_base/test_state_filters.py | 429 +++++++++++++++++ 11 files changed, 1214 insertions(+), 23 deletions(-) create mode 100644 docs/adapters/state-filters.md create mode 100644 src/layerlens/instrument/adapters/_base/state_filters.py create mode 100644 tests/instrument/adapters/_base/test_state_filters.py diff --git a/docs/adapters/state-filters.md b/docs/adapters/state-filters.md new file mode 100644 index 00000000..812507d2 --- /dev/null +++ b/docs/adapters/state-filters.md @@ -0,0 +1,153 @@ +# Adapter State Filters + +Framework adapters emit dict-shaped state into trace events +(`agent.input`, `agent.output`, `tool.call`, `tool.result`, +`model.invoke`, etc.). Without filtering, that state can carry +credentials, PII, or unbounded cardinality straight into telemetry +sinks. The state-filter subsystem is the last line of defence between +user state and the wire. + +## Default behaviour + +Every multi-agent framework adapter ships with a conservative default +filter — `default_state_filter()` — that excludes a denylist of common +PII and credential field names. Customers who do nothing still get +baseline protection out of the box. + +The default denylist (case-insensitive substring match) covers: + +- **Credentials**: `password`, `passwd`, `pwd`, `api_key`, `apikey`, + `api_secret`, `secret`, `secret_key`, `access_token`, + `refresh_token`, `auth_token`, `bearer_token`, `token`, + `session_token`, `cookie`, `cookies`, `private_key`, + `client_secret`, `service_account` +- **Personal identifiers**: `ssn`, `social_security`, + `social_security_number`, `tax_id`, `national_id`, `passport`, + `passport_number`, `drivers_license` +- **Financial**: `credit_card`, `credit_card_number`, `card_number`, + `cvv`, `cvc`, `iban`, `account_number`, `routing_number` +- **Contact / location**: `email`, `email_address`, `phone`, + `phone_number`, `address`, `street_address`, `home_address`, + `billing_address`, `shipping_address` +- **Authn material**: `authorization`, `x-api-key`, `set-cookie` + +Substring matching after non-alphanumeric normalisation means +`X-Api-Key`, `stripe_customer_email`, `USER_API_KEY`, and +`customer.email_address` all match without the caller having to +enumerate every variant. + +## Configuration + +Pass a `StateFilter` instance to the adapter constructor: + +```python +from layerlens.instrument.adapters._base import StateFilter +from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + +# Add custom keys to the default denylist +filter = StateFilter.with_extra_excludes(["internal_user_id", "session_attributes"]) +adapter = AgnoAdapter(client, state_filter=filter) +``` + +### Three filter operations + +A `StateFilter` declares three operations, applied in this order: + +1. **`exclude_keys`** — keys (case-insensitive substring match) + removed from the output entirely. Defaults to the PII denylist. +2. **`mask_keys`** — keys (case-insensitive substring match) whose + values are replaced with `[REDACTED]`. The key remains visible (so + dashboards see the field exists), but the value is hidden. Default + empty — opt-in. +3. **`include_keys`** — if non-empty, restricts the output to ONLY + these keys (case-insensitive equality). Acts as a strict + allowlist after exclude/mask have run. + +```python +filter = StateFilter( + exclude_keys=frozenset({"password", "ssn"}), # drop entirely + mask_keys=frozenset({"phone", "address"}), # show key, hide value + include_keys=frozenset({"name", "phone", "model"}), # allowlist +) +``` + +### Disabling the filter + +For tests or explicit opt-out, use `StateFilter.permissive()`: + +```python +adapter = AgnoAdapter(client, state_filter=StateFilter.permissive()) +``` + +This is **strongly discouraged** in production. The active filter is +surfaced in `adapter.adapter_info().metadata['state_filter']` so +operators can detect accidental disablement. + +## Recursion + +By default, the filter walks nested dicts (and dicts inside lists) +recursively, so a structure like: + +```python +{ + "messages": [ + {"role": "user", "content": "hi", "api_key": "sk-..."}, + ], +} +``` + +emits as: + +```python +{ + "messages": [ + {"role": "user", "content": "hi"}, + ], + "_filtered_keys": ["api_key"], +} +``` + +Set `recursive=False` to filter only the top-level dict. + +## Auditability + +Every event payload that has been touched by the filter carries a +`_filtered_keys` field listing the (lowercased) names of every key +that was excluded or masked anywhere in the payload. Operators can +correlate this with the active filter config (in +`adapter_info().metadata['state_filter']`) to verify exactly what was +clipped without exposing the values themselves. + +## Replay reproducibility + +The active filter snapshot is included under +`serialize_state_filter_for_replay()` so the replay engine can +reconstruct an equivalent `StateFilter` on the other side. Replays +must apply the SAME filter as the original run so the captured payload +shapes match. + +## Multi-agent adapters with state-filter wiring + +| Adapter | Status | Notes | +|----------------|----------|--------------------------------------------------------| +| `agno` | Wired | Filter applied on `agent.input`/`output`/`tool.*` | +| `openai_agents`| Wired | Filter on `agent.input`, `tool.call/result`, generation messages | +| `llamaindex` | Wired | Filter on LLM messages, retrieval, query, agent_step | +| `google_adk` | Wired | Filter on user_content (agent + tool input/output) | +| `strands` | Wired | Filter on invocation messages + tool input/output | +| `pydantic_ai` | Wired | Filter on agent input/output, deps_summary, tool args | + +For mature adapters (LangChain, LangGraph, CrewAI, AutoGen, +Agentforce, Semantic Kernel) state filtering is performed at the +framework-native layer (e.g. LangGraph's `LangGraphStateAdapter` +include/exclude keys) — the cross-pollination here brings the same +contract to the lighter multi-agent adapters. + +## Reference + +- Implementation: + `src/layerlens/instrument/adapters/_base/state_filters.py` +- Adapter-base wiring: + `src/layerlens/instrument/adapters/frameworks/_base_framework.py` +- Tests: `tests/instrument/adapters/_base/test_state_filters.py` +- See also: [Data Privacy](../security/data-privacy.md) diff --git a/src/layerlens/instrument/adapters/_base/__init__.py b/src/layerlens/instrument/adapters/_base/__init__.py index c2775780..70a80412 100644 --- a/src/layerlens/instrument/adapters/_base/__init__.py +++ b/src/layerlens/instrument/adapters/_base/__init__.py @@ -16,13 +16,27 @@ get_default_for, resilient_callback, ) +from .state_filters import ( + DEFAULT_PII_EXCLUDE_KEYS, + REDACTED_PLACEHOLDER, + StateFilter, + default_state_filter, + filter_payload_fields, + filter_state, +) __all__ = [ "AdapterInfo", "BaseAdapter", "DEFAULT_FAILURE_THRESHOLD", + "DEFAULT_PII_EXCLUDE_KEYS", "HealthStatus", + "REDACTED_PLACEHOLDER", "ResilienceTracker", + "StateFilter", + "default_state_filter", + "filter_payload_fields", + "filter_state", "get_default_for", "resilient_callback", ] diff --git a/src/layerlens/instrument/adapters/_base/state_filters.py b/src/layerlens/instrument/adapters/_base/state_filters.py new file mode 100644 index 00000000..72e35e8f --- /dev/null +++ b/src/layerlens/instrument/adapters/_base/state_filters.py @@ -0,0 +1,444 @@ +"""Per-key allowlist / denylist / mask filters for adapter state payloads. + +LangGraph's ``LangGraphStateAdapter`` (see +``src/layerlens/instrument/adapters/frameworks/langgraph/state.py`` on +``main`` — the mature reference implementation) supports include/exclude +key filters at the state-snapshot level so customers can scrub sensitive +state (api_keys, tokens, PII) WITHOUT modifying their agent code or +doing post-hoc redaction. + +Lighter multi-agent adapters (agno, openai_agents, llama_index, +google_adk, strands, pydantic_ai, ms_agent_framework) emit dict-shaped +state into ``agent.input`` / ``agent.output`` / ``agent.state.change`` +events. Without filtering, that state can carry credentials, PII, or +unbounded cardinality straight into telemetry sinks. + +This module provides: + +* :class:`StateFilter` — Pydantic-style config object capturing the + three filter operations (exclude, mask, include-allowlist). +* :func:`filter_state` — pure function that applies a filter to a dict + recursively. +* :data:`DEFAULT_PII_EXCLUDE_KEYS` — conservative default denylist that + matches common PII / credential field names (case-insensitive + substring match) so customers who forget to configure a filter still + get sensible protection. +* :func:`default_state_filter` — factory for the default PII-aware + filter installed by every framework adapter unless the customer + overrides it. + +The filter is intentionally cross-cutting: framework adapters expose a +``state_filter`` constructor parameter (defaulting to +``default_state_filter()``), keep it reachable via +``self._state_filter``, and pass dict-shaped payload fields through +:func:`filter_state` before they are emitted. Multi-tenancy is +preserved by applying the SAME default filter regardless of org — every +customer gets baseline PII protection out of the box. + +Auditability: :func:`filter_state` returns a 2-tuple of +``(filtered_dict, filtered_keys)`` so callers can record the names of +any keys that were excluded or masked. Adapters surface this list as +``_filtered_keys`` metadata on the emitted event so customers can see +exactly what was clipped from the payload. + +This module is **adapter-internal infrastructure**. It is NOT public +API for end users — there are no version guarantees on the helpers +exposed here, only on the BaseAdapter contract. +""" + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Tuple, Iterable, Optional, FrozenSet +from dataclasses import field, dataclass + +# --------------------------------------------------------------------------- +# Public constants +# --------------------------------------------------------------------------- + + +REDACTED_PLACEHOLDER: str = "[REDACTED]" +"""String used to replace masked values. + +Picked as a recognisable, non-PII string that is short enough not to +inflate payload size and obvious enough that downstream operators +immediately understand the value was clipped on purpose. +""" + + +DEFAULT_PII_EXCLUDE_KEYS: FrozenSet[str] = frozenset( + { + # Credentials + "password", + "passwd", + "pwd", + "api_key", + "apikey", + "api_secret", + "secret", + "secret_key", + "access_token", + "refresh_token", + "auth_token", + "bearer_token", + "token", + "session_token", + "cookie", + "cookies", + "private_key", + "client_secret", + "service_account", + # Personal identifiers + "ssn", + "social_security", + "social_security_number", + "tax_id", + "national_id", + "passport", + "passport_number", + "drivers_license", + # Financial + "credit_card", + "credit_card_number", + "card_number", + "cvv", + "cvc", + "iban", + "account_number", + "routing_number", + # Contact / location + "email", + "email_address", + "phone", + "phone_number", + "address", + "street_address", + "home_address", + "billing_address", + "shipping_address", + # Authn material + "authorization", + "x-api-key", + "set-cookie", + } +) +"""Default exclude-key denylist. + +The check performed by :func:`filter_state` is **case-insensitive +substring** — so a key named ``"customer_email"`` matches the entry +``"email"`` and is filtered. This catches the long tail of vendor- or +team-specific field names (e.g. ``USER_API_KEY``, +``stripe_customer_email``, ``X-Api-Key``) without forcing the caller +to enumerate every variant. + +The list is conservative on purpose: false positives (filtering a +field that was not actually PII) are recoverable by the customer +re-emitting telemetry with a custom :class:`StateFilter`. False +negatives (a credential leaking into a sink) are not. +""" + + +# --------------------------------------------------------------------------- +# StateFilter dataclass +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class StateFilter: + """Declarative filter applied to dict-shaped adapter state payloads. + + Three operations, applied in this order during :func:`filter_state`: + + 1. **exclude_keys** — keys (case-insensitive substring match) + removed from the output entirely. + 2. **mask_keys** — keys (case-insensitive substring match) whose + values are replaced with :data:`REDACTED_PLACEHOLDER`. The key + remains visible (so dashboards see the field exists), but the + value is hidden. + 3. **include_keys** — if non-empty, the output is restricted to + only these keys (case-insensitive equality match against the + full key name; substring match would be too lossy when the + caller explicitly named an allowlist). + + Order matters: ``exclude`` runs first (cheapest, fully removes + keys), ``mask`` next (rewrites values), and ``include`` last so the + allowlist can still narrow the masked-but-allowed surface. + + Parameters + ---------- + include_keys: + If non-empty, ONLY keys equal (case-insensitive) to one of + these values are kept after exclude/mask. Default ``None`` + means "keep everything that survives exclude/mask". + exclude_keys: + Keys whose name *contains* any of these substrings (case- + insensitive) are removed. Defaults to + :data:`DEFAULT_PII_EXCLUDE_KEYS`. + mask_keys: + Keys whose name *contains* any of these substrings (case- + insensitive) have their value replaced with + :data:`REDACTED_PLACEHOLDER`. Default empty — opt-in. + recursive: + When ``True`` (default), nested dicts (and dicts inside lists) + are also filtered. When ``False``, only the top-level dict is + examined — useful when the caller knows nested structures are + already safe. + """ + + include_keys: Optional[FrozenSet[str]] = None + exclude_keys: FrozenSet[str] = field(default_factory=lambda: DEFAULT_PII_EXCLUDE_KEYS) + mask_keys: FrozenSet[str] = field(default_factory=frozenset) + recursive: bool = True + + def __post_init__(self) -> None: + # Normalize all key collections to lowercase frozensets for + # case-insensitive comparison. Frozen dataclass means we must + # reach through ``object.__setattr__`` to mutate. + if self.include_keys is not None: + object.__setattr__( + self, + "include_keys", + frozenset(k.lower() for k in self.include_keys), + ) + object.__setattr__( + self, + "exclude_keys", + frozenset(k.lower() for k in self.exclude_keys), + ) + object.__setattr__( + self, + "mask_keys", + frozenset(k.lower() for k in self.mask_keys), + ) + + # -- factory helpers -------------------------------------------------- + + @classmethod + def permissive(cls) -> "StateFilter": + """Filter that does NOTHING — for tests / explicit opt-out.""" + return cls(include_keys=None, exclude_keys=frozenset(), mask_keys=frozenset()) + + @classmethod + def with_extra_excludes(cls, extra: Iterable[str]) -> "StateFilter": + """Default PII filter PLUS the caller's additional excludes.""" + merged = frozenset(k.lower() for k in DEFAULT_PII_EXCLUDE_KEYS) | frozenset(k.lower() for k in extra) + return cls(exclude_keys=merged) + + # -- public introspection -------------------------------------------- + + def as_metadata(self) -> Dict[str, Any]: + """Snapshot of this filter for inclusion in adapter / replay metadata.""" + meta: Dict[str, Any] = { + "exclude_keys_count": len(self.exclude_keys), + "mask_keys_count": len(self.mask_keys), + "recursive": self.recursive, + } + if self.include_keys is not None: + meta["include_keys_count"] = len(self.include_keys) + # Allowlists are usually short and intentional — surface them + # so customers can verify exactly what they configured. + meta["include_keys"] = sorted(self.include_keys) + return meta + + +# --------------------------------------------------------------------------- +# Default factory +# --------------------------------------------------------------------------- + + +def default_state_filter() -> StateFilter: + """Return the conservative default filter installed by adapters. + + Excludes the built-in :data:`DEFAULT_PII_EXCLUDE_KEYS` denylist with + no additional masks or allowlist. Customers who do nothing still + get baseline PII protection on every emitted state payload. + """ + return StateFilter() + + +# --------------------------------------------------------------------------- +# Core filtering function +# --------------------------------------------------------------------------- + + +def filter_state( + state: Any, + filter: StateFilter, +) -> Tuple[Any, List[str]]: + """Apply *filter* to *state* and return (filtered_state, filtered_keys). + + ``state`` may be any value. Filtering is only applied when the value + (or, recursively, a nested value) is a ``dict``. Non-dict primitives + pass through unchanged. + + Returns + ------- + filtered_state: + The filtered value, with the same shape as the input but with + sensitive keys excluded or masked. + filtered_keys: + Sorted list of unique key names (lowercased) that were either + excluded or masked anywhere in the structure. Adapters surface + this list as ``_filtered_keys`` metadata so customers can see + what was clipped without exposing the values themselves. + + Notes + ----- + * The *filter* parameter is positional but named ``filter`` for + readability at call sites — even though it shadows the Python + builtin, the scope is local and there is no real ambiguity. + * Sets / tuples are NOT recursed into. Only ``dict`` and ``list`` + are walked. This matches the LangGraph reference implementation. + """ + filtered_keys_set: set[str] = set() + out = _filter_value(state, filter, filtered_keys_set) + return out, sorted(filtered_keys_set) + + +def _filter_value( + value: Any, + flt: StateFilter, + filtered_keys: set[str], +) -> Any: + """Recursive helper that mutates *filtered_keys* as a side effect.""" + if isinstance(value, dict): + return _filter_dict(value, flt, filtered_keys) + if flt.recursive and isinstance(value, list): + return [_filter_value(item, flt, filtered_keys) for item in value] + return value + + +def _filter_dict( + state: Dict[Any, Any], + flt: StateFilter, + filtered_keys: set[str], +) -> Dict[Any, Any]: + """Apply exclude → mask → include to a single dict.""" + out: Dict[Any, Any] = {} + for key, value in state.items(): + key_norm = str(key).lower() + + # 1. Exclude — drop the key entirely. + if _matches_substring(key_norm, flt.exclude_keys): + filtered_keys.add(key_norm) + continue + + # 2. Mask — keep the key, replace the value. + if _matches_substring(key_norm, flt.mask_keys): + filtered_keys.add(key_norm) + out[key] = REDACTED_PLACEHOLDER + continue + + # 3. Recurse for nested dicts / lists when requested. + if flt.recursive: + value = _filter_value(value, flt, filtered_keys) + + out[key] = value + + # 4. Include allowlist — applied AFTER exclude/mask so the allowlist + # narrows the surviving surface, never widens it. + if flt.include_keys is not None: + narrowed: Dict[Any, Any] = {} + for key, value in out.items(): + if str(key).lower() in flt.include_keys: + narrowed[key] = value + else: + filtered_keys.add(str(key).lower()) + return narrowed + + return out + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +_NON_WORD_RE = re.compile(r"[^a-z0-9]+") + + +def _normalize_key_for_match(key: str) -> str: + """Strip non-alphanumeric chars so e.g. ``X-Api-Key`` matches ``api_key``. + + Both the candidate key and the configured substring are normalised + via this function before substring containment is tested. This makes + the filter resilient to header-style ``-``, screaming snake case, + camelCase, and other delimiter variants. + """ + return _NON_WORD_RE.sub("", key.lower()) + + +def _matches_substring(key_norm: str, needles: FrozenSet[str]) -> bool: + """Return ``True`` when *key_norm* contains any of *needles*. + + Both sides are normalised to alphanumeric-only lowercase so the + comparison is robust across naming conventions. + """ + if not needles: + return False + candidate = _normalize_key_for_match(key_norm) + for needle in needles: + if _normalize_key_for_match(needle) in candidate: + return True + return False + + +# --------------------------------------------------------------------------- +# Convenience: filter only specified payload fields +# --------------------------------------------------------------------------- + + +def filter_payload_fields( + payload: Dict[str, Any], + flt: StateFilter, + fields: Iterable[str], +) -> List[str]: + """In-place filter that only touches *fields* of *payload*. + + Many adapter payloads mix safe scalar metadata (``model``, + ``latency_ms``, ``agent_name``) with potentially sensitive + dict-shaped state (``input``, ``output``, ``messages``, ``deps``). + Filtering the entire payload would rewrite the metadata too. This + helper applies :func:`filter_state` ONLY to the named fields when + they are dict-shaped (or list-of-dict-shaped) and leaves everything + else untouched. + + The function records every key that was clipped, attaches the + sorted list to ``payload['_filtered_keys']``, and returns the same + list so the caller can inspect it without re-reading the payload. + + Returns the list of filtered keys (possibly empty). + """ + all_filtered: set[str] = set() + for fname in fields: + if fname not in payload: + continue + original = payload[fname] + if not isinstance(original, (dict, list)): + continue + filtered, keys = filter_state(original, flt) + payload[fname] = filtered + all_filtered.update(keys) + + if all_filtered: + sorted_keys = sorted(all_filtered) + # Merge with any pre-existing _filtered_keys (the caller may have + # filtered another payload section earlier in the same emit). + existing = payload.get("_filtered_keys") + if isinstance(existing, list): + merged = sorted(set(existing) | set(sorted_keys)) + payload["_filtered_keys"] = merged + return merged + payload["_filtered_keys"] = sorted_keys + return sorted_keys + return [] + + +__all__ = [ + "DEFAULT_PII_EXCLUDE_KEYS", + "REDACTED_PLACEHOLDER", + "StateFilter", + "default_state_filter", + "filter_payload_fields", + "filter_state", +] diff --git a/src/layerlens/instrument/adapters/frameworks/_base_framework.py b/src/layerlens/instrument/adapters/frameworks/_base_framework.py index 9de15a01..ece7151e 100644 --- a/src/layerlens/instrument/adapters/frameworks/_base_framework.py +++ b/src/layerlens/instrument/adapters/frameworks/_base_framework.py @@ -12,7 +12,14 @@ import threading from typing import Any, Dict, Optional -from .._base import AdapterInfo, BaseAdapter, ResilienceTracker +from .._base import ( + AdapterInfo, + BaseAdapter, + StateFilter, + ResilienceTracker, + default_state_filter, + filter_payload_fields, +) from ..._context import ( RunState, _pop_span, @@ -48,7 +55,12 @@ def _check_dependency(self, available: bool) -> None: "Install it with: pip install layerlens[%s]" % (pkg, self.name, pkg) ) - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: self._client = client self._config = capture_config or CaptureConfig.standard() self._lock = threading.Lock() @@ -61,6 +73,16 @@ def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) # has set it (class-level), otherwise fall back to the class name. adapter_name = getattr(type(self), "name", None) or type(self).__name__ self._resilience: ResilienceTracker = ResilienceTracker(adapter_name) + # State filtering: every framework adapter gets a per-instance + # filter so dict-shaped state (agent input, output, messages, + # deps, args, etc.) is scrubbed of common PII / credentials + # before it reaches a telemetry sink. Customers who don't pass a + # ``state_filter`` get :func:`default_state_filter` which excludes + # the conservative :data:`DEFAULT_PII_EXCLUDE_KEYS` denylist — + # baseline protection out of the box per CLAUDE.md ("never silently + # leak customer data"). Customers who explicitly pass + # ``StateFilter.permissive()`` opt out. + self._state_filter: StateFilter = state_filter if state_filter is not None else default_state_filter() # Public, mypy-friendly alias of the failure counter — kept as a # property-shaped int so external callers can read it without # importing ResilienceTracker. @@ -235,6 +257,41 @@ def _set_if_capturing(self, payload: Dict[str, Any], key: str, value: Any) -> No if self._config.capture_content and value is not None: payload[key] = value + # ------------------------------------------------------------------ + # State filtering — applied to dict-shaped payload fields before emit + # ------------------------------------------------------------------ + + def _filter_payload(self, payload: Dict[str, Any], *fields: str) -> None: + """Filter dict-shaped *fields* of *payload* through the state filter. + + Drops or masks keys per ``self._state_filter`` and records any + clipped keys under ``payload['_filtered_keys']`` for audit + visibility (per the `auditable` contract called out in the + cross-pollination audit §2.12). + + Adapters call this immediately before ``self._emit(...)`` for + any payload that may contain user-controlled state (input, + output, messages, deps, args, state, context, etc.). + + No-op when *fields* are absent from the payload, when the + values are not dict / list shapes, or when the filter has no + rules to apply. + """ + if not fields: + return + filter_payload_fields(payload, self._state_filter, fields) + + def serialize_state_filter_for_replay(self) -> Dict[str, Any]: + """Snapshot of the active state filter for replay reproducibility. + + Replays must apply the SAME filter as the original run so the + captured payload shapes match. Adapters that implement a + ``serialize_for_replay`` method include this dict under a + ``state_filter`` key so the replay engine can reconstruct an + equivalent :class:`StateFilter` on the other side. + """ + return self._state_filter.as_metadata() + # ------------------------------------------------------------------ # Event emission # ------------------------------------------------------------------ @@ -328,8 +385,16 @@ def adapter_info(self) -> AdapterInfo: # Merge live resilience snapshot into the metadata block so # ``adapter_info().metadata['resilience_status']`` reports # HEALTHY / DEGRADED to monitoring code without each subclass - # having to remember to do it. - merged_metadata: Dict[str, Any] = {**self._metadata, **self._resilience.as_metadata()} + # having to remember to do it. Also surface the active state + # filter config under ``state_filter`` so operators can verify + # what's being scrubbed (or audit that the default PII + # protection wasn't accidentally disabled with + # ``StateFilter.permissive()``). + merged_metadata: Dict[str, Any] = { + **self._metadata, + **self._resilience.as_metadata(), + "state_filter": self._state_filter.as_metadata(), + } return AdapterInfo( name=self.name, adapter_type="framework", diff --git a/src/layerlens/instrument/adapters/frameworks/agno.py b/src/layerlens/instrument/adapters/frameworks/agno.py index ba10fcd2..ca918f85 100644 --- a/src/layerlens/instrument/adapters/frameworks/agno.py +++ b/src/layerlens/instrument/adapters/frameworks/agno.py @@ -3,7 +3,7 @@ import logging from typing import Any, Dict, List, Optional -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -140,8 +140,13 @@ class AgnoAdapter(FrameworkAdapter): name = "agno" - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - super().__init__(client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + super().__init__(client, capture_config, state_filter=state_filter) self._originals: Dict[int, Dict[str, Any]] = {} self._wrapped_agents: List[Any] = [] @@ -255,6 +260,7 @@ def _on_run_start(self, agent: Any, input_data: Any) -> None: if model: payload["model"] = model self._set_if_capturing(payload, "input", safe_serialize(input_data)) + self._filter_payload(payload, "input") self._emit("agent.input", payload, span_id=root, parent_span_id=None, span_name=f"agno:{name}") @resilient_callback(callback_name="_on_run_end") @@ -284,6 +290,7 @@ def _emit_output(self, agent: Any, result: Any, error: Optional[Exception]) -> N payload["error"] = str(error) payload["error_type"] = type(error).__name__ self._set_if_capturing(payload, "output", safe_serialize(output)) + self._filter_payload(payload, "output") self._emit("agent.output", payload, span_id=root, parent_span_id=None, span_name=f"agno:{name}") def _emit_model(self, agent: Any, result: Any) -> None: @@ -310,12 +317,14 @@ def _emit_tools(self, result: Any) -> None: call_payload = self._payload(tool_name=tool["tool_name"]) self._set_if_capturing(call_payload, "input", safe_serialize(tool.get("tool_args"))) + self._filter_payload(call_payload, "input") self._emit("tool.call", call_payload, span_id=span_id, parent_span_id=root) result_payload = self._payload(tool_name=tool["tool_name"]) self._set_if_capturing(result_payload, "output", safe_serialize(tool.get("result"))) if tool.get("latency_ms") is not None: result_payload["latency_ms"] = tool["latency_ms"] + self._filter_payload(result_payload, "output") self._emit("tool.result", result_payload, span_id=span_id, parent_span_id=root) diff --git a/src/layerlens/instrument/adapters/frameworks/google_adk.py b/src/layerlens/instrument/adapters/frameworks/google_adk.py index 396b4380..6b4e755b 100644 --- a/src/layerlens/instrument/adapters/frameworks/google_adk.py +++ b/src/layerlens/instrument/adapters/frameworks/google_adk.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, Optional -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -53,8 +53,13 @@ class GoogleADKAdapter(FrameworkAdapter): name = "google_adk" - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - super().__init__(client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + super().__init__(client, capture_config, state_filter=state_filter) self._collector: Optional[TraceCollector] = None self._run_span_id: Optional[str] = None self._agent_span_ids: Dict[str, str] = {} @@ -176,6 +181,9 @@ def _on_before_run(self, invocation_context: Any) -> None: user_content = getattr(invocation_context, "user_content", None) self._set_if_capturing(payload, "input", safe_serialize(user_content)) + # User content + agent_tree (which can contain sub-agent + # config) are user-controlled; filter PII keys. + self._filter_payload(payload, "input", "agent_tree") self._fire("agent.input", payload, span_id=span_id, span_name=agent_name) @resilient_callback(callback_name="_on_after_run") @@ -208,6 +216,7 @@ def _on_before_agent(self, agent: Any, callback_context: Any) -> None: payload = self._payload(agent_name=name) user_content = getattr(callback_context, "user_content", None) self._set_if_capturing(payload, "input", safe_serialize(user_content)) + self._filter_payload(payload, "input") self._fire("agent.input", payload, span_id=span_id, parent_span_id=self._run_span_id, span_name=f"agent:{name}") @resilient_callback(callback_name="_on_after_agent") @@ -307,10 +316,14 @@ def _on_after_tool(self, tool: Any, tool_args: Any, tool_context: Any, result: A self._set_if_capturing(call_payload, "input", safe_serialize(tool_args)) if latency_ms is not None: call_payload["latency_ms"] = latency_ms + # Tool args + tool result are user-controlled; both are + # high-frequency credential leak vectors. + self._filter_payload(call_payload, "input") self._fire("tool.call", call_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") result_payload = self._payload(tool_name=tool_name) self._set_if_capturing(result_payload, "output", safe_serialize(result)) + self._filter_payload(result_payload, "output") self._fire("tool.result", result_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") @resilient_callback(callback_name="_on_tool_error") diff --git a/src/layerlens/instrument/adapters/frameworks/llamaindex.py b/src/layerlens/instrument/adapters/frameworks/llamaindex.py index 008e02a4..40d3d893 100644 --- a/src/layerlens/instrument/adapters/frameworks/llamaindex.py +++ b/src/layerlens/instrument/adapters/frameworks/llamaindex.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, List, Optional -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -69,8 +69,13 @@ class LlamaIndexAdapter(FrameworkAdapter): "ReRankEndEvent": "_on_rerank_end", } - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - super().__init__(client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + super().__init__(client, capture_config, state_filter=state_filter) self._span_handler: Optional[Any] = None self._event_handler: Optional[Any] = None # Per-root-span collectors (concurrent query support) @@ -243,6 +248,11 @@ def _on_llm_chat_end(self, event: Any) -> None: if output: payload["output_message"] = output + # Scrub PII / credentials from messages before they hit the + # event sink. ``output_message`` is a string scalar in this + # adapter, so the filter is a no-op for it; we still pass it + # through so any future structured-output change is covered. + self._filter_payload(payload, "messages", "output_message") self._fire("model.invoke", payload, span_id=span_id) if tokens: @@ -288,6 +298,7 @@ def _on_llm_completion_end(self, event: Any) -> None: if text: payload["output_message"] = str(text) + self._filter_payload(payload, "messages", "output_message") self._fire("model.invoke", payload, span_id=span_id) if tokens: @@ -317,6 +328,8 @@ def _on_tool_call(self, event: Any) -> None: if desc: payload["tool_description"] = str(desc) + # Tool input arguments are user-supplied; filter PII keys. + self._filter_payload(payload, "input") self._fire("tool.call", payload, span_id=span_id) # ------------------------------------------------------------------ @@ -331,6 +344,7 @@ def _on_retrieval_start(self, event: Any) -> None: query = getattr(event, "str_or_query_bundle", None) if query is not None: payload["input"] = str(query) + self._filter_payload(payload, "input") self._fire("tool.call", payload, span_id=span_id, span_name="retrieval") @resilient_callback(callback_name="_on_retrieval_end") @@ -342,6 +356,9 @@ def _on_retrieval_end(self, event: Any) -> None: payload["num_results"] = len(nodes) if self._config.capture_content: payload["output"] = _serialize_nodes(nodes) + # Retrieval results carry the node text — same PII risk as + # any document store payload. Filter dict-shaped node entries. + self._filter_payload(payload, "output") self._fire("tool.result", payload, span_id=span_id, span_name="retrieval") # ------------------------------------------------------------------ @@ -405,6 +422,7 @@ def _on_query_start(self, event: Any) -> None: query = getattr(event, "query", None) if query is not None: payload["input"] = str(query) + self._filter_payload(payload, "input") self._fire("agent.input", payload, span_id=span_id, span_name="query") @resilient_callback(callback_name="_on_query_end") @@ -415,6 +433,7 @@ def _on_query_end(self, event: Any) -> None: response = getattr(event, "response", None) if response is not None: payload["output"] = str(response) + self._filter_payload(payload, "output") self._fire("agent.output", payload, span_id=span_id, span_name="query") # ------------------------------------------------------------------ @@ -432,6 +451,7 @@ def _on_agent_step_start(self, event: Any) -> None: step_input = getattr(event, "input", None) if step_input is not None: payload["input"] = safe_serialize(step_input) + self._filter_payload(payload, "input") self._fire("agent.input", payload, span_id=span_id, span_name="agent_step") @resilient_callback(callback_name="_on_agent_step_end") @@ -442,6 +462,7 @@ def _on_agent_step_end(self, event: Any) -> None: output = getattr(event, "step_output", None) if output is not None: payload["output"] = safe_serialize(output) + self._filter_payload(payload, "output") self._fire("agent.output", payload, span_id=span_id, span_name="agent_step") # ------------------------------------------------------------------ diff --git a/src/layerlens/instrument/adapters/frameworks/openai_agents.py b/src/layerlens/instrument/adapters/frameworks/openai_agents.py index a354824c..e0f10a5a 100644 --- a/src/layerlens/instrument/adapters/frameworks/openai_agents.py +++ b/src/layerlens/instrument/adapters/frameworks/openai_agents.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional from datetime import datetime -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ..._context import RunState, _current_run, _current_collector from ..._collector import TraceCollector @@ -58,8 +58,13 @@ class OpenAIAgentsAdapter(*_Bases): "response": "_handle_response_span", } - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - FrameworkAdapter.__init__(self, client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + FrameworkAdapter.__init__(self, client, capture_config, state_filter=state_filter) # trace_id -> RunState for concurrent trace isolation self._trace_runs: Dict[str, Any] = {} @@ -146,7 +151,10 @@ def _handle_agent_span(self, span: Any) -> None: val = getattr(data, key, None) if val: input_payload[key] = val - + # Tools/handoffs/output_type may carry user-controlled + # configuration (e.g. tool function args schemas referencing + # secrets); run them through the filter before emit. + self._filter_payload(input_payload, "tools", "handoffs", "output_type") self._emit( "agent.input", input_payload, @@ -194,6 +202,12 @@ def _handle_generation_span(self, span: Any) -> None: self._set_if_capturing(payload, "messages", safe_serialize(getattr(data, "input", None))) self._set_if_capturing(payload, "output_message", safe_serialize(getattr(data, "output", None))) + # Messages and output frequently contain prompt text — scrub + # any structured PII / credential fields before they reach a + # sink. ``filter_payload_fields`` is a no-op when these fields + # are absent (capture_content disabled) or scalar (already a + # string). + self._filter_payload(payload, "messages", "output_message", "model_config") if span.error: payload["error"] = safe_serialize(span.error) @@ -248,6 +262,10 @@ def _handle_function_span(self, span: Any) -> None: ) if resource_ref: call_payload["mcp_resource_uri"] = str(resource_ref) + # Tool input args + parameters_schema + mcp_data are user-defined + # and are the most common vector for credential leakage in + # agent traces. + self._filter_payload(call_payload, "input", "parameters_schema", "mcp_data") self._emit("tool.call", call_payload, span_id=span_id, parent_span_id=parent_id) # Emit tool.result or agent.error @@ -262,6 +280,7 @@ def _handle_function_span(self, span: Any) -> None: self._set_if_capturing(result_payload, "output", safe_serialize(getattr(data, "output", None))) if duration_ms is not None: result_payload["latency_ms"] = duration_ms + self._filter_payload(result_payload, "output") self._emit("tool.result", result_payload, span_id=span_id, parent_span_id=parent_id) def _handle_handoff_span(self, span: Any) -> None: diff --git a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py index a97913eb..5055b02b 100644 --- a/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py +++ b/src/layerlens/instrument/adapters/frameworks/pydantic_ai.py @@ -3,7 +3,7 @@ import logging from typing import Any, Dict, Optional -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ._base_framework import FrameworkAdapter from ..._capture_config import CaptureConfig @@ -40,8 +40,13 @@ class PydanticAIAdapter(FrameworkAdapter): name = "pydantic-ai" - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - super().__init__(client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + super().__init__(client, capture_config, state_filter=state_filter) self._target: Any = None self._hooks: Any = None @@ -134,6 +139,12 @@ def _on_before_run(self, ctx: Any) -> None: safe_serialize(deps)[:500] if isinstance(safe_serialize(deps), str) else _summarize_deps(deps) ) + # Filter agent input + deps_summary before emit. ``deps`` in + # PydanticAI is the canonical request-scoped object — DB + # handles, request_id, user objects all live here. The filter + # is the last line of defence between the deps shape and the + # event sink. + self._filter_payload(payload, "input", "deps_summary") self._emit( "agent.input", payload, @@ -160,6 +171,7 @@ def _on_after_run(self, ctx: Any, *, result: Any) -> Any: payload["latency_ms"] = latency_ms self._set_if_capturing(payload, "output", output) payload.update(usage) + self._filter_payload(payload, "output") self._emit( "agent.output", payload, @@ -243,6 +255,7 @@ def _on_after_model_request( "input", safe_serialize(getattr(part, "args", None)), ) + self._filter_payload(tool_payload, "input") self._emit("tool.call", tool_payload) return response @@ -317,6 +330,7 @@ def _on_after_tool_execute( self._set_if_capturing(payload, "output", safe_serialize(result)) if latency_ms is not None: payload["latency_ms"] = latency_ms + self._filter_payload(payload, "output") self._emit("tool.result", payload, span_id=span_id) return result @@ -389,6 +403,7 @@ def _on_after_stream(self, ctx: Any, *, response: Any = None, **_kwargs: Any) -> output = self._extract_output(response) if output is not None: payload["output_message"] = output + self._filter_payload(payload, "output_message") self._emit("model.invoke", payload) # ------------------------------------------------------------------ diff --git a/src/layerlens/instrument/adapters/frameworks/strands.py b/src/layerlens/instrument/adapters/frameworks/strands.py index 60168f4c..6888b23f 100644 --- a/src/layerlens/instrument/adapters/frameworks/strands.py +++ b/src/layerlens/instrument/adapters/frameworks/strands.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, Optional -from .._base import resilient_callback +from .._base import StateFilter, resilient_callback from ._utils import safe_serialize from ..._collector import TraceCollector from ._base_framework import FrameworkAdapter @@ -53,8 +53,13 @@ class StrandsAdapter(FrameworkAdapter): name = "strands" - def __init__(self, client: Any, capture_config: Optional[CaptureConfig] = None) -> None: - super().__init__(client, capture_config) + def __init__( + self, + client: Any, + capture_config: Optional[CaptureConfig] = None, + state_filter: Optional[StateFilter] = None, + ) -> None: + super().__init__(client, capture_config, state_filter=state_filter) self._collector: Optional[TraceCollector] = None self._run_span_id: Optional[str] = None self._current_agent_name: Optional[str] = None @@ -203,6 +208,7 @@ def _on_before_invocation(self, event: Any) -> None: messages = getattr(event, "messages", None) self._set_if_capturing(payload, "input", safe_serialize(messages)) + self._filter_payload(payload, "input") self._fire("agent.input", payload, span_id=span_id, span_name=name) @resilient_callback(callback_name="_on_after_invocation") @@ -230,6 +236,7 @@ def _on_after_invocation(self, event: Any) -> None: # so we read per-cycle tokens here instead. self._emit_per_cycle_tokens(agent) + self._filter_payload(payload, "output") self._fire("agent.output", payload, span_id=span_id, span_name=name) self._end_trace() @@ -310,6 +317,7 @@ def _on_after_tool(self, event: Any) -> None: self._set_if_capturing(call_payload, "input", safe_serialize(tool_input)) if latency_ms is not None: call_payload["latency_ms"] = latency_ms + self._filter_payload(call_payload, "input") self._fire("tool.call", call_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}") result = getattr(event, "result", None) @@ -326,6 +334,7 @@ def _on_after_tool(self, event: Any) -> None: result_payload["error"] = str(exception) result_payload["error_type"] = type(exception).__name__ + self._filter_payload(result_payload, "output") self._fire( "tool.result", result_payload, span_id=span_id, parent_span_id=parent, span_name=f"tool:{tool_name}" ) diff --git a/tests/instrument/adapters/_base/test_state_filters.py b/tests/instrument/adapters/_base/test_state_filters.py new file mode 100644 index 00000000..ebb63b8d --- /dev/null +++ b/tests/instrument/adapters/_base/test_state_filters.py @@ -0,0 +1,429 @@ +"""Tests for the per-key allowlist / denylist / mask filter for adapter state. + +Covers :class:`StateFilter`, :func:`filter_state`, :func:`filter_payload_fields`, +:data:`DEFAULT_PII_EXCLUDE_KEYS`, and the ``_filter_payload`` / +``serialize_state_filter_for_replay`` integration on every multi-agent +framework adapter. + +Every test asserts behaviour that prevents PII / credentials from leaving +the adapter — the filter is the last line of defence between user state +and a telemetry sink, so failure modes here are CRITICAL severity per +CLAUDE.md ("never silently leak customer data"). +""" + +from __future__ import annotations + +from typing import Any, Dict, List +from unittest.mock import Mock + +import pytest + +from layerlens.instrument.adapters._base import ( + REDACTED_PLACEHOLDER, + DEFAULT_PII_EXCLUDE_KEYS, + StateFilter, + filter_state, + default_state_filter, + filter_payload_fields, +) +from layerlens.instrument.adapters.frameworks._base_framework import FrameworkAdapter + +# --------------------------------------------------------------------------- +# StateFilter dataclass behaviour +# --------------------------------------------------------------------------- + + +class TestStateFilterConstruction: + """The filter normalises its inputs so callers can pass any iterable shape.""" + + def test_default_is_pii_aware(self) -> None: + """A ``StateFilter()`` with no args keeps the conservative default + denylist — the "I forgot to configure" path must still scrub + common PII out of the box. + """ + f = StateFilter() + # The constructor lowercases the frozenset; assert membership + # rather than identity. + assert "password" in f.exclude_keys + assert "api_key" in f.exclude_keys + assert "ssn" in f.exclude_keys + assert f.include_keys is None + assert f.mask_keys == frozenset() + + def test_lowercases_exclude_keys(self) -> None: + f = StateFilter(exclude_keys=frozenset({"PASSWORD", "Authorization"})) + assert "password" in f.exclude_keys + assert "authorization" in f.exclude_keys + + def test_lowercases_include_keys(self) -> None: + f = StateFilter(include_keys=frozenset({"User_ID", "MODEL"})) + assert f.include_keys is not None + assert "user_id" in f.include_keys + assert "model" in f.include_keys + + def test_permissive_factory(self) -> None: + """``permissive()`` removes all rules — used in tests / explicit opt-out.""" + f = StateFilter.permissive() + assert f.exclude_keys == frozenset() + assert f.mask_keys == frozenset() + assert f.include_keys is None + + def test_with_extra_excludes_factory(self) -> None: + f = StateFilter.with_extra_excludes(["custom_secret", "internal_id"]) + assert "password" in f.exclude_keys # default still present + assert "custom_secret" in f.exclude_keys + assert "internal_id" in f.exclude_keys + + +class TestStateFilterMetadata: + """``as_metadata`` produces a stable, dashboard-safe snapshot.""" + + def test_default_metadata_shape(self) -> None: + meta = StateFilter().as_metadata() + assert meta["exclude_keys_count"] == len(DEFAULT_PII_EXCLUDE_KEYS) + assert meta["mask_keys_count"] == 0 + assert meta["recursive"] is True + # Allowlist is None → key not surfaced. + assert "include_keys" not in meta + + def test_allowlist_surfaces_in_metadata(self) -> None: + f = StateFilter(include_keys=frozenset({"foo", "bar"})) + meta = f.as_metadata() + # Allowlists are intentionally short — surface them so customers + # can verify exactly what they configured. + assert meta["include_keys"] == ["bar", "foo"] + assert meta["include_keys_count"] == 2 + + +# --------------------------------------------------------------------------- +# filter_state — exclude / mask / include precedence +# --------------------------------------------------------------------------- + + +class TestFilterStateExclude: + def test_excludes_default_pii_keys(self) -> None: + state = {"username": "alice", "password": "hunter2", "api_key": "sk-..."} + out, keys = filter_state(state, default_state_filter()) + assert "username" in out + assert "password" not in out + assert "api_key" not in out + # filtered_keys reports BOTH excluded names — auditable trail. + assert sorted(keys) == ["api_key", "password"] + + def test_substring_match_catches_vendor_variants(self) -> None: + """``X-Api-Key``, ``stripe_customer_email``, ``USER_API_KEY`` should all match.""" + state = { + "X-Api-Key": "sk-secret", + "stripe_customer_email": "alice@example.com", + "USER_API_KEY": "user-secret", + "model": "gpt-5", + } + out, keys = filter_state(state, default_state_filter()) + assert "model" in out + assert "X-Api-Key" not in out + assert "stripe_customer_email" not in out + assert "USER_API_KEY" not in out + + def test_excludes_nothing_when_disabled(self) -> None: + state = {"password": "hunter2", "api_key": "sk-..."} + out, keys = filter_state(state, StateFilter.permissive()) + assert out == state + assert keys == [] + + +class TestFilterStateMask: + def test_masks_keys_keeps_field_visible(self) -> None: + f = StateFilter(exclude_keys=frozenset(), mask_keys=frozenset({"phone"})) + state = {"name": "Alice", "phone": "555-1234"} + out, keys = filter_state(state, f) + # Key remains so dashboards see the field exists, value is REDACTED. + assert out == {"name": "Alice", "phone": REDACTED_PLACEHOLDER} + assert keys == ["phone"] + + def test_mask_runs_before_recurse(self) -> None: + """A masked key's nested structure is NOT walked — the value + is replaced wholesale so nested PII can't leak through. + """ + f = StateFilter(exclude_keys=frozenset(), mask_keys=frozenset({"profile"})) + state = { + "profile": {"email": "alice@example.com", "phone": "555-1234"}, + } + out, _ = filter_state(state, f) + assert out == {"profile": REDACTED_PLACEHOLDER} + + +class TestFilterStateInclude: + def test_include_acts_as_allowlist(self) -> None: + f = StateFilter( + exclude_keys=frozenset(), + include_keys=frozenset({"model", "tokens_total"}), + ) + state = {"model": "gpt-5", "tokens_total": 100, "input": "secret prompt"} + out, keys = filter_state(state, f) + assert out == {"model": "gpt-5", "tokens_total": 100} + assert "input" in keys + + def test_include_runs_after_exclude(self) -> None: + """Even an allowlisted key is still removed if it matches exclude.""" + f = StateFilter( + include_keys=frozenset({"password", "model"}), # allow password + exclude_keys=frozenset({"password"}), # but also exclude it + ) + state = {"password": "hunter2", "model": "gpt-5"} + out, _ = filter_state(state, f) + assert "password" not in out # exclude wins + assert out == {"model": "gpt-5"} + + +class TestFilterStateRecursive: + def test_recurses_into_nested_dicts(self) -> None: + state = { + "user": {"name": "Alice", "password": "hunter2"}, + "model": "gpt-5", + } + out, keys = filter_state(state, default_state_filter()) + assert out == {"user": {"name": "Alice"}, "model": "gpt-5"} + assert "password" in keys + + def test_recurses_into_lists_of_dicts(self) -> None: + state = { + "messages": [ + {"role": "user", "content": "hi", "api_key": "sk-..."}, + {"role": "assistant", "content": "hello"}, + ], + } + out, keys = filter_state(state, default_state_filter()) + assert out["messages"][0] == {"role": "user", "content": "hi"} + assert out["messages"][1] == {"role": "assistant", "content": "hello"} + assert "api_key" in keys + + def test_non_recursive_skips_nested(self) -> None: + f = StateFilter(recursive=False) + state = {"user": {"password": "hunter2"}} + out, _ = filter_state(state, f) + # Top-level only — nested password survives because recursion off. + assert out == {"user": {"password": "hunter2"}} + + +class TestFilterStatePassthrough: + """Non-dict / non-list inputs pass through unchanged.""" + + @pytest.mark.parametrize("value", [None, 0, 1.5, "hello", True, b"bytes"]) + def test_primitives_pass_through(self, value: Any) -> None: + out, keys = filter_state(value, default_state_filter()) + assert out == value + assert keys == [] + + def test_empty_dict_returns_empty(self) -> None: + out, keys = filter_state({}, default_state_filter()) + assert out == {} + assert keys == [] + + +# --------------------------------------------------------------------------- +# filter_payload_fields — surgical filter for adapter use +# --------------------------------------------------------------------------- + + +class TestFilterPayloadFields: + def test_only_named_fields_are_filtered(self) -> None: + """Scalar metadata (model, latency_ms) is left alone; only the + named dict-shaped fields are scrubbed. + """ + payload: Dict[str, Any] = { + "model": "gpt-5", + "latency_ms": 42, + "input": {"user": "alice", "password": "hunter2"}, + } + clipped = filter_payload_fields(payload, default_state_filter(), ["input"]) + assert payload["model"] == "gpt-5" # untouched + assert payload["latency_ms"] == 42 # untouched + assert payload["input"] == {"user": "alice"} # filtered + assert clipped == ["password"] + assert payload["_filtered_keys"] == ["password"] + + def test_missing_fields_are_skipped(self) -> None: + payload: Dict[str, Any] = {"model": "gpt-5"} + clipped = filter_payload_fields(payload, default_state_filter(), ["input", "deps"]) + assert clipped == [] + assert "_filtered_keys" not in payload + + def test_scalar_field_is_not_filtered(self) -> None: + """An ``input`` field that's already a string passes through.""" + payload: Dict[str, Any] = {"input": "hello world"} + clipped = filter_payload_fields(payload, default_state_filter(), ["input"]) + assert payload["input"] == "hello world" + assert clipped == [] + + def test_merges_with_existing_filtered_keys(self) -> None: + """Multiple filter passes accumulate the filtered-key list.""" + payload: Dict[str, Any] = { + "_filtered_keys": ["password"], + "output": {"name": "alice", "ssn": "111-22-3333"}, + } + filter_payload_fields(payload, default_state_filter(), ["output"]) + # Both old + new are surfaced, sorted, deduped. + assert payload["_filtered_keys"] == ["password", "ssn"] + + +# --------------------------------------------------------------------------- +# FrameworkAdapter integration — every adapter must wire the filter +# --------------------------------------------------------------------------- + + +class _StubAdapter(FrameworkAdapter): + """Minimal concrete adapter so we can exercise base wiring.""" + + name = "stub" + + def _on_connect(self, target: Any = None, **kwargs: Any) -> None: + pass + + def _on_disconnect(self) -> None: + pass + + +class TestFrameworkAdapterStateFilterDefaults: + def test_default_filter_installed_on_construction(self) -> None: + a = _StubAdapter(client=Mock()) + # Default filter excludes the PII denylist — verify by snapshot. + meta = a._state_filter.as_metadata() + assert meta["exclude_keys_count"] == len(DEFAULT_PII_EXCLUDE_KEYS) + + def test_custom_filter_overrides_default(self) -> None: + custom = StateFilter.permissive() + a = _StubAdapter(client=Mock(), state_filter=custom) + assert a._state_filter is custom + + def test_filter_payload_drops_pii(self) -> None: + a = _StubAdapter(client=Mock()) + payload: Dict[str, Any] = { + "model": "gpt-5", + "input": {"user": "alice", "api_key": "sk-secret"}, + } + a._filter_payload(payload, "input") + assert payload["input"] == {"user": "alice"} + assert "api_key" in payload["_filtered_keys"] + + def test_serialize_state_filter_for_replay(self) -> None: + """Replay must capture the filter so the replay engine can + reconstruct an equivalent filter on the other side. + """ + a = _StubAdapter(client=Mock()) + snap = a.serialize_state_filter_for_replay() + assert snap["recursive"] is True + assert snap["exclude_keys_count"] == len(DEFAULT_PII_EXCLUDE_KEYS) + + def test_state_filter_appears_in_adapter_info(self) -> None: + """``adapter_info().metadata['state_filter']`` lets operators + verify what's being scrubbed (and detect accidental + ``StateFilter.permissive()``). + """ + a = _StubAdapter(client=Mock()) + info = a.adapter_info() + assert "state_filter" in info.metadata + assert info.metadata["state_filter"]["exclude_keys_count"] == len(DEFAULT_PII_EXCLUDE_KEYS) + + +# --------------------------------------------------------------------------- +# Per-adapter constructor wiring — the 6 multi-agent adapters present on +# this base. ms_agent_framework is enumerated in the audit but doesn't +# exist on this branch's history; will be wired when its adapter lands +# on `feat/instrument-callback-resilience` (or its successor). +# --------------------------------------------------------------------------- + + +_PARAM_ADAPTERS: List[Any] = [] +try: + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + _PARAM_ADAPTERS.append(("agno", AgnoAdapter)) +except Exception: + pass +try: + from layerlens.instrument.adapters.frameworks.openai_agents import OpenAIAgentsAdapter + + _PARAM_ADAPTERS.append(("openai_agents", OpenAIAgentsAdapter)) +except Exception: + pass +try: + from layerlens.instrument.adapters.frameworks.llamaindex import LlamaIndexAdapter + + _PARAM_ADAPTERS.append(("llamaindex", LlamaIndexAdapter)) +except Exception: + pass +try: + from layerlens.instrument.adapters.frameworks.google_adk import GoogleADKAdapter + + _PARAM_ADAPTERS.append(("google_adk", GoogleADKAdapter)) +except Exception: + pass +try: + from layerlens.instrument.adapters.frameworks.strands import StrandsAdapter + + _PARAM_ADAPTERS.append(("strands", StrandsAdapter)) +except Exception: + pass +try: + from layerlens.instrument.adapters.frameworks.pydantic_ai import PydanticAIAdapter + + _PARAM_ADAPTERS.append(("pydantic_ai", PydanticAIAdapter)) +except Exception: + pass + + +@pytest.mark.parametrize(("name", "cls"), _PARAM_ADAPTERS) +class TestPerAdapterStateFilterWiring: + """Verify every multi-agent adapter accepts ``state_filter`` and wires it correctly.""" + + def test_constructor_accepts_state_filter(self, name: str, cls: Any) -> None: + custom = StateFilter.permissive() + adapter = cls(client=Mock(), state_filter=custom) + assert adapter._state_filter is custom + + def test_default_state_filter_is_pii_aware(self, name: str, cls: Any) -> None: + adapter = cls(client=Mock()) + # Out of the box: every adapter excludes the PII denylist. + assert "password" in adapter._state_filter.exclude_keys + assert "api_key" in adapter._state_filter.exclude_keys + + def test_state_filter_surfaces_in_adapter_info(self, name: str, cls: Any) -> None: + adapter = cls(client=Mock()) + info = adapter.adapter_info() + assert "state_filter" in info.metadata + + +# --------------------------------------------------------------------------- +# End-to-end: filter applied at the adapter's emit boundary +# --------------------------------------------------------------------------- + + +class TestEndToEndAgnoFilter: + """Use the agno adapter (which doesn't require optional deps for the + pure ``_filter_payload`` path) to demonstrate that the filter actually + runs at the emit boundary, not just sits idle on the adapter. + """ + + def test_filter_payload_emits_filtered_keys_metadata(self) -> None: + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + adapter = AgnoAdapter(client=Mock()) + payload: Dict[str, Any] = { + "agent_name": "demo", + "input": {"prompt": "hi", "api_key": "sk-secret"}, + } + adapter._filter_payload(payload, "input") + assert payload["input"] == {"prompt": "hi"} + assert payload["_filtered_keys"] == ["api_key"] + + def test_filter_payload_with_permissive_filter_is_noop(self) -> None: + from layerlens.instrument.adapters.frameworks.agno import AgnoAdapter + + adapter = AgnoAdapter(client=Mock(), state_filter=StateFilter.permissive()) + payload: Dict[str, Any] = { + "input": {"prompt": "hi", "api_key": "sk-secret"}, + } + adapter._filter_payload(payload, "input") + # Permissive filter touches nothing. + assert payload["input"] == {"prompt": "hi", "api_key": "sk-secret"} + assert "_filtered_keys" not in payload