diff --git a/docs/adapters/frameworks-agentforce.md b/docs/adapters/frameworks-agentforce.md new file mode 100644 index 00000000..db47aa42 --- /dev/null +++ b/docs/adapters/frameworks-agentforce.md @@ -0,0 +1,133 @@ +# Salesforce Agentforce framework adapter + +`layerlens.instrument.adapters.frameworks.agentforce.AgentForceAdapter` +imports Salesforce Agentforce session traces from Data Cloud DMOs and emits +them as LayerLens events. + +This adapter is **import-mode** rather than runtime-instrumentation: it +authenticates against a Salesforce org via OAuth 2.0 JWT Bearer and runs +SOQL queries against the AgentForce DMO objects to backfill trace data. + +## Install + +```bash +pip install 'layerlens[agentforce]' +``` + +Pulls `requests>=2.28`. The Salesforce credentials must be provisioned +out-of-band (Connected App + private key + permitted user). + +## Quick start + +```python +from layerlens.instrument.adapters.frameworks.agentforce import ( + AgentForceAdapter, + SalesforceCredentials, +) +from layerlens.instrument.transport.sink_http import HttpEventSink + +credentials = SalesforceCredentials( + client_id="3MVG9...", + username="agent-importer@example.com", + private_key="env:SALESFORCE_PRIVATE_KEY", # or file path or raw PEM + instance_url="https://example.my.salesforce.com", +) + +sink = HttpEventSink(adapter_name="salesforce_agentforce") +adapter = AgentForceAdapter(credentials=credentials) +adapter.add_sink(sink) +adapter.connect() # JWT flow runs here + +result = adapter.import_sessions( + start_date="2026-04-01", + end_date="2026-04-25", + limit=100, +) +print(f"Imported {result.events_generated} events from {result.sessions_imported} sessions") + +adapter.disconnect() +sink.close() +``` + +## What's wrapped + +This adapter does not monkey-patch anything. 
It calls SOQL against: + +- `AIAgentSession` — top-level session record +- `AIAgentSessionParticipant` — agents + users in the session +- `AIAgentInteraction` — turns within the session +- `AIAgentInteractionStep` — individual steps inside an interaction +- `AIAgentInteractionMessage` — raw input/output messages + +Each row is normalized via `AgentForceNormalizer` and emitted through the +adapter's `emit_dict_event` pipeline. + +Companion modules: + +- `AgentApiClient` — direct REST client for Agent API (real-time capture) +- `PlatformEventSubscriber` — gRPC Pub/Sub subscriber for near-real-time +- `TrustLayerImporter` — imports Einstein Trust Layer policies +- `EinsteinEvaluator` — runs LLM evaluation scenarios + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `agent.input` | L1 | Per `AIAgentInteractionMessage` with role=user. | +| `agent.output` | L1 | Per `AIAgentInteractionMessage` with role=agent. | +| `agent.action` | L4a | Per `AIAgentInteractionStep`. | +| `tool.call` | L5a | Per step where `StepType` is a tool/action invocation. | +| `model.invoke` | L3 | Per LLM call captured in step metadata. | +| `policy.violation` | cross-cutting | Per Einstein Trust Layer policy hit. | +| `agent.handoff` | L4a | Per `AIAgentSessionParticipant` change. | + +Each emitted event includes `_identity` (the Salesforce record `Id`) and +`_timestamp` (record `LastModifiedDate`) for re-import idempotency. + +## Salesforce specifics + +- **Authentication**: JWT Bearer (OAuth 2.0). `SalesforceCredentials` accepts + the private key as a raw PEM, an `env:NAME` reference, or a file path. +- **Token lifetime**: ~2 hours. The adapter re-authenticates automatically + when the cached token expires before the next operation. +- **Rate limits**: a warning is logged when the API daily limit consumption + passes 80%. Salesforce returns the consumption in response headers. 
+- **Incremental sync**: pass `last_import_timestamp` to + `import_sessions(...)` to fetch only records modified since a watermark. +- **Batch size**: configurable via the `batch_size` constructor arg + (default 200, the SOQL maximum is 2000). + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Recommended for compliance backfills. +adapter = AgentForceAdapter( + credentials=credentials, + capture_config=CaptureConfig.standard(), +) + +# Strip raw message bodies, keep only structural events. +adapter = AgentForceAdapter( + credentials=credentials, + capture_config=CaptureConfig( + l1_agent_io=True, + l4a_environment_config=True, + capture_content=False, + ), +) +``` + +## BYOK + +Salesforce manages its own model keys (Einstein Trust Layer abstracts the +provider). The adapter does not own model API keys. The Salesforce +credentials themselves are intended to live in atlas-app's `byok_credentials` +table once M1.B ships — see `docs/adapters/byok.md`. + +## Replay + +`adapter.serialize_for_replay()` returns a `ReplayableTrace` with all events +captured during the current `import_sessions` call. Replay is a re-emit +operation: the adapter does not re-query Salesforce. diff --git a/docs/adapters/frameworks-autogen.md b/docs/adapters/frameworks-autogen.md new file mode 100644 index 00000000..586273c7 --- /dev/null +++ b/docs/adapters/frameworks-autogen.md @@ -0,0 +1,111 @@ +# AutoGen framework adapter + +`layerlens.instrument.adapters.frameworks.autogen.AutoGenAdapter` instruments +[Microsoft AutoGen](https://github.com/microsoft/autogen) `ConversableAgent` +objects, capturing message exchange, LLM calls, code execution, and group-chat +turns. + +## Install + +```bash +pip install 'layerlens[autogen]' +``` + +Pulls `pyautogen>=0.2,<0.5`. 
+ +## Quick start + +```python +from autogen import AssistantAgent, UserProxyAgent + +from layerlens.instrument.adapters.frameworks.autogen import ( + AutoGenAdapter, + instrument_agents, +) +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="autogen") +adapter = AutoGenAdapter() +adapter.add_sink(sink) +adapter.connect() + +assistant = AssistantAgent(name="assistant", llm_config={"model": "gpt-4o-mini"}) +user = UserProxyAgent(name="user", human_input_mode="NEVER", code_execution_config=False) + +adapter.connect_agents(assistant, user) +user.initiate_chat(assistant, message="What is 2+2?", max_turns=1) + +adapter.disconnect() +sink.close() +``` + +`instrument_agents(*agents)` is the one-line equivalent of the three-line +adapter setup above. + +## What's wrapped + +`adapter.connect_agents(*agents)` monkey-patches the following on each +`ConversableAgent`: + +- `send` — emits `agent.input` for the outgoing message. +- `receive` — emits `agent.output` for the incoming message. +- `generate_reply` — emits `model.invoke` and any `tool.call` events. +- `execute_code_blocks` — emits `agent.code` and `tool.call` for code + execution (when present). + +The originals are stashed on the adapter and restored on `disconnect()`. +A `GroupChatTracer` wires similar hooks onto `GroupChatManager`, and a +`HumanProxyTracer` adds `agent.handoff` semantics for human-in-the-loop +proxies. + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `environment.config` | L4a | First time each agent is seen. | +| `agent.input` | L1 | Every `send`. | +| `agent.output` | L1 | Every `receive`. | +| `agent.action` | L4a | Per `generate_reply` decision. | +| `agent.code` | L2 | When `execute_code_blocks` runs and `l2_agent_code` is enabled. | +| `agent.handoff` | L4a | Group-chat speaker selection / human handoff. | +| `agent.state.change` | cross-cutting | Conversation history mutations. 
| +| `tool.call` | L5a | Per function-call inside `generate_reply`. | +| `model.invoke` | L3 | Per LLM call. | + +## AutoGen specifics + +- **Multi-agent attribution**: `agent_name`, `recipient_name`, and + `message_seq` (a monotonic counter) are included on every event so the + full chat can be reconstructed in order. +- **Group chats**: `GroupChatTracer` registers as a callback on + `GroupChatManager`, capturing the speaker-selection turns. Pass a + `GroupChatManager` to `connect_agents` alongside the participants. +- **Code execution**: when an agent runs code blocks, the language and + truncated code body emit `agent.code` (only if + `CaptureConfig.l2_agent_code` is enabled). + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Recommended. +adapter = AutoGenAdapter(capture_config=CaptureConfig.standard()) + +# Production-light: skip the verbose code-execution events. +adapter = AutoGenAdapter( + capture_config=CaptureConfig( + l1_agent_io=True, + l3_model_metadata=True, + l4a_environment_config=True, + l5a_tool_calls=True, + l2_agent_code=False, + ), +) +``` + +## BYOK + +AutoGen reads its `llm_config` to instantiate provider clients. The adapter +does not own those keys. For platform-managed BYOK see +`docs/adapters/byok.md` (atlas-app M1.B). diff --git a/docs/adapters/frameworks-crewai.md b/docs/adapters/frameworks-crewai.md new file mode 100644 index 00000000..659f0df1 --- /dev/null +++ b/docs/adapters/frameworks-crewai.md @@ -0,0 +1,122 @@ +# CrewAI framework adapter + +`layerlens.instrument.adapters.frameworks.crewai.CrewAIAdapter` instruments +[CrewAI](https://github.com/joaomdmoura/crewai) crews — multi-agent +collaborations with explicit role assignments and task delegation. + +## Install + +```bash +pip install 'layerlens[crewai]' +``` + +Pulls `crewai>=0.30,<0.90`. 
+ +## Quick start + +```python +from crewai import Agent, Crew, Task + +from layerlens.instrument.adapters.frameworks.crewai import ( + CrewAIAdapter, + instrument_crew, +) +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="crewai") +adapter = CrewAIAdapter() +adapter.add_sink(sink) +adapter.connect() + +# build a one-task crew +researcher = Agent(role="Researcher", goal="Answer", backstory="...") +task = Task(description="What is 2 + 2?", agent=researcher, expected_output="A number") +crew = Crew(agents=[researcher], tasks=[task]) + +instrumented = adapter.instrument_crew(crew) +result = instrumented.kickoff() + +adapter.disconnect() +sink.close() +``` + +The `instrument_crew(crew)` convenience helper wraps the whole flow above: + +```python +instrumented = instrument_crew(crew) # connects + wraps in one call +result = instrumented.kickoff() +``` + +## What's wrapped + +`adapter.instrument_crew(crew)` patches the following on the underlying +crew + agent objects: + +- `Crew.kickoff` — emits `agent.input` + `environment.config` at start, + `agent.output` at completion. +- `Agent.execute_task` — emits `agent.action` + `agent.code` (when enabled) + per task. +- `Agent._invoke_tool` — emits `tool.call` per tool invocation. +- `Agent.llm.call` — emits `model.invoke` per LLM call. +- Crew delegation events via `CrewDelegationTracker` — emits `agent.handoff` + on `Agent.delegate` calls. + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `environment.config` | L4a | First `kickoff` of an instrumented crew. | +| `agent.input` | L1 | Start of every `kickoff`. | +| `agent.output` | L1 | End of every `kickoff` and per-task. | +| `agent.code` | L2 | When `CaptureConfig.l2_agent_code` is true; one per agent. | +| `agent.action` | L4a | Per task execution. | +| `agent.state.change` | cross-cutting | When agent memory or context changes. | +| `agent.handoff` | L4a | When one agent delegates to another. 
| +| `tool.call` | L5a | Per tool invocation inside a task. | +| `model.invoke` | L3 | Per LLM call from any crew agent. | + +## CrewAI specifics + +- **Multi-agent attribution**: every event payload includes `agent_id`, + `agent_role`, and (when present) `task_id` so the platform can reconstruct + who-did-what across a crew. +- **Memory tracking**: when a `memory_service` is passed to + `CrewAIAdapter(memory_service=...)`, agent short-term memory writes emit + `agent.state.change` with the memory diff. +- **Sequential vs hierarchical**: works for both `Process.sequential` and + `Process.hierarchical`. Hierarchical delegation is captured via the + delegation tracker. + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Recommended. +adapter = CrewAIAdapter(capture_config=CaptureConfig.standard()) + +# Add agent.code (the prompt template / system message of each agent). +adapter = CrewAIAdapter( + capture_config=CaptureConfig( + l1_agent_io=True, + l2_agent_code=True, + l3_model_metadata=True, + l5a_tool_calls=True, + ), +) +``` + +## BYOK + +CrewAI agents instantiate their own LLM clients (LangChain or LiteLLM under +the hood). The CrewAI adapter does not own those keys. For platform-managed +BYOK see `docs/adapters/byok.md` (atlas-app M1.B). + +## Backward compatibility + +```python +from layerlens.instrument.adapters.frameworks.crewai import STRATIXCrewCallback +``` + +`STRATIXCrewCallback` is an alias for `LayerLensCrewCallback` and will be +removed in v2.0. 
diff --git a/docs/adapters/frameworks-langchain.md b/docs/adapters/frameworks-langchain.md new file mode 100644 index 00000000..f5ffa63e --- /dev/null +++ b/docs/adapters/frameworks-langchain.md @@ -0,0 +1,122 @@ +# LangChain framework adapter + +`layerlens.instrument.adapters.frameworks.langchain.LayerLensCallbackHandler` +implements the LangChain callback interface to emit LayerLens telemetry on +every LLM call, tool invocation, agent step, and chain execution. + +## Install + +```bash +pip install 'layerlens[langchain]' +``` + +Pulls `langchain>=0.2,<0.4` and `langchain-core>=0.2,<0.4`. + +## Quick start + +```python +from langchain_openai import ChatOpenAI +from langchain_core.prompts import ChatPromptTemplate + +from layerlens.instrument.adapters.frameworks.langchain import ( + LayerLensCallbackHandler, + instrument_chain, +) +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="langchain") +handler = LayerLensCallbackHandler() +handler.add_sink(sink) +handler.connect() + +llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[handler]) +prompt = ChatPromptTemplate.from_messages([("user", "{q}")]) +chain = prompt | llm + +result = chain.invoke({"q": "What is 2 + 2?"}, config={"callbacks": [handler]}) + +handler.disconnect() +sink.close() +``` + +The same handler can be passed to any LangChain component that accepts a +`callbacks` list — `ChatOpenAI`, `LLMChain`, `AgentExecutor`, custom tools, etc. + +## What's wrapped + +The handler implements the LangChain callback methods: + +- `on_chat_model_start`, `on_llm_start`, `on_llm_end`, `on_llm_error` +- `on_tool_start`, `on_tool_end`, `on_tool_error` +- `on_agent_action`, `on_agent_finish` +- `on_chain_start`, `on_chain_end`, `on_chain_error` + +Convenience helpers wrap whole objects: + +- `instrument_chain(chain, stratix=...)` — returns a `TracedChain` that injects + the handler into every `invoke`/`batch`/`stream` call. 
+- `instrument_agent(agent, stratix=...)` — returns a `TracedAgent` that wraps + an `AgentExecutor`. +- `wrap_memory(memory, ...)` — returns a `TracedMemory` that emits + `agent.state.change` on `save_context` / `clear`. + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `model.invoke` | L3 | `on_llm_end` (success) and `on_llm_error` (failure). | +| `tool.call` | L5a | `on_tool_end` and `on_tool_error`. | +| `agent.output` | L4a | `on_agent_finish`. | +| `agent.action` | L4a | `on_agent_action`. | +| `chain.start` / `chain.end` / `chain.error` | L4a | `on_chain_*`. | +| `agent.state.change` | L4a | When a wrapped memory is updated via `wrap_memory`. | + +The `model.invoke` payload includes the resolved provider (extracted from +the LangChain `serialized` dict — `openai`, `anthropic`, `bedrock`, etc.), +model name, prompts, generations, token usage if present, and latency. + +## LangGraph nodes + +When the handler is used inside a LangGraph run, the `metadata.langgraph_node` +field in the LangChain callback metadata is propagated to the +`agent.action` / `chain.start` payloads as `node_name`. This lets the platform +correlate per-node events back to the graph topology — see also the +`langgraph` adapter for full graph instrumentation. + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Production-light: only L1 + protocol discovery + lifecycle. +handler = LayerLensCallbackHandler(capture_config=CaptureConfig.minimal()) + +# Recommended: L1 + L3 + L4a + L5a + L6. +handler = LayerLensCallbackHandler(capture_config=CaptureConfig.standard()) + +# Hand-rolled — keep tokens/costs but redact prompt/response content. +handler = LayerLensCallbackHandler( + capture_config=CaptureConfig( + l3_model_metadata=True, + capture_content=False, + ), +) +``` + +## BYOK + +LangChain manages model API keys via the underlying provider client +(`ChatOpenAI`, `ChatAnthropic`, etc.). The handler does not touch them. 
For +centrally-managed keys see the platform-side BYOK store in +`docs/adapters/byok.md` (atlas-app M1.B, in flight). + +## Backward compatibility + +Users coming from `ateam` can keep importing the old name: + +```python +from layerlens.instrument.adapters.frameworks.langchain import STRATIXCallbackHandler +``` + +`STRATIXCallbackHandler` is an alias for `LayerLensCallbackHandler` and will +be removed in v2.0. diff --git a/docs/adapters/frameworks-langgraph.md b/docs/adapters/frameworks-langgraph.md new file mode 100644 index 00000000..c46fc9d3 --- /dev/null +++ b/docs/adapters/frameworks-langgraph.md @@ -0,0 +1,114 @@ +# LangGraph framework adapter + +`layerlens.instrument.adapters.frameworks.langgraph.LayerLensLangGraphAdapter` +instruments LangGraph state machines, capturing graph execution, node +transitions, state snapshots, and agent handoffs. + +## Install + +```bash +pip install 'layerlens[langgraph]' +``` + +Pulls `langgraph>=0.2,<0.4`. The `langchain` extra is recommended too if you +use LangChain-based nodes inside the graph. + +## Quick start + +```python +from langgraph.graph import StateGraph, END + +from layerlens.instrument.adapters.frameworks.langgraph import LayerLensLangGraphAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="langgraph") +adapter = LayerLensLangGraphAdapter() +adapter.add_sink(sink) +adapter.connect() + +graph = StateGraph(dict) +graph.add_node("greet", lambda s: {"msg": "hi"}) +graph.set_entry_point("greet") +graph.add_edge("greet", END) +compiled = graph.compile() + +traced = adapter.wrap_graph(compiled) +result = traced.invoke({}) + +adapter.disconnect() +sink.close() +``` + +## What's wrapped + +`adapter.wrap_graph(compiled_graph)` returns a wrapper that proxies +`invoke`, `ainvoke`, `stream`, and `astream`. Each call: + +- Begins a `GraphExecution` and emits `environment.config` + `agent.input`. +- Tracks state hashes before/after to detect mutations. 
+- On completion emits `agent.output` and (if state changed) `agent.state.change`. +- On error emits `agent.output` with the exception captured. + +Companion utilities: + +- `trace_node(fn)` — decorator that wraps an individual node function and + emits `agent.action` on entry/exit. +- `trace_langgraph_tool(fn)` — decorator for tool nodes; emits `tool.call`. +- `wrap_llm_for_langgraph(llm, ...)` — wraps a LangGraph LLM node so each + invocation emits `model.invoke`. +- `HandoffDetector` — pluggable detector that compares pre/post states for + `__next__` agent transitions and emits `agent.handoff`. + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `environment.config` | L4a | First call into a wrapped graph (per execution). | +| `agent.input` | L1 | Beginning of every wrapped graph execution. | +| `agent.output` | L1 | End of every wrapped graph execution (success or error). | +| `agent.state.change` | cross-cutting | When the state hash changes during execution. | +| `agent.action` | L4a | One per node entry/exit when the node is wrapped with `trace_node`. | +| `tool.call` | L5a | One per tool node wrapped with `trace_langgraph_tool`. | +| `model.invoke` | L3 | One per LLM call wrapped with `wrap_llm_for_langgraph`. | +| `agent.handoff` | L4a | When a `HandoffDetector` is attached and a handoff is detected. | + +## State serialization + +State is serialized via `LangGraphStateAdapter.get_hash` (sha256 of a JSON +form) for the `before_hash` / `after_hash` fields. The full state appears in +`agent.input` / `agent.output` only when `CaptureConfig.capture_content` is +true. Non-JSON-serializable values are coerced via `repr()`; the original +state object is never mutated. + +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Recommended: L1 + L3 + L4a + L5a + L6. +adapter = LayerLensLangGraphAdapter(capture_config=CaptureConfig.standard()) + +# Strip prompt/response content but keep structural events. 
+adapter = LayerLensLangGraphAdapter( + capture_config=CaptureConfig( + l1_agent_io=True, + l4a_environment_config=True, + capture_content=False, + ), +) +``` + +## BYOK + +LangGraph nodes call their underlying providers directly — `ChatOpenAI`, +`ChatAnthropic`, etc. The LangGraph adapter does not own model API keys. +For platform-managed BYOK see `docs/adapters/byok.md` (atlas-app M1.B). + +## Backward compatibility + +```python +from layerlens.instrument.adapters.frameworks.langgraph import STRATIXLangGraphAdapter +``` + +`STRATIXLangGraphAdapter` is an alias for `LayerLensLangGraphAdapter` and will +be removed in v2.0. diff --git a/pyproject.toml b/pyproject.toml index ae6d1dc7..fa7bb567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,19 @@ classifiers = [ [project.optional-dependencies] cli = ["click>=8.0.0"] +# --- Instrument layer: framework adapters (orchestration tier) --- +# Adding any extra below MUST keep the default `pip install layerlens` +# install set unchanged. Verified by `tests/instrument/test_default_install.py`. +langchain = [ + "langchain>=0.2,<0.4; python_version >= '3.9'", + "langchain-core>=0.2,<0.4; python_version >= '3.9'", +] +langgraph = ["langgraph>=0.2,<0.4; python_version >= '3.9'"] +crewai = ["crewai>=0.30,<0.90; python_version >= '3.10'"] +autogen = ["pyautogen>=0.2,<0.5; python_version >= '3.10'"] +agentforce = ["requests>=2.28"] +langfuse-importer = [] # uses stdlib urllib + [project.urls] Homepage = "https://github.com/LayerLens/stratix-python" Repository = "https://github.com/LayerLens/stratix-python" @@ -139,14 +152,21 @@ known-first-party = ["openai", "tests"] "tests/**.py" = ["T201", "T203", "ARG", "B007"] "examples/**.py" = ["T201", "T203"] "src/layerlens/cli/**" = ["T201", "T203"] +# Framework callbacks have signatures dictated by upstream — unused +# arguments are part of the contract, not a code smell. 
+"src/layerlens/instrument/adapters/frameworks/**.py" = ["ARG002"] [tool.pyright] include = ["src", "tests"] exclude = ["**/__pycache__"] reportMissingTypeStubs = false -# Less strict settings for tests and cli +# Less strict settings for tests, cli, and the dynamic-monkey-patching +# framework adapter code. mypy --strict stays strict for these dirs; +# pyright is relaxed here because it can't follow runtime attribute +# mutation that the framework instrumentation relies on. executionEnvironments = [ { root = "src/layerlens/cli", reportMissingImports = false, reportFunctionMemberAccess = false, reportCallIssue = false, reportArgumentType = false, reportAttributeAccessIssue = false }, + { root = "src/layerlens/instrument/adapters/frameworks", reportPossiblyUnbound = false, reportPossiblyUnboundVariable = false, reportCallIssue = false, reportAttributeAccessIssue = false, reportArgumentType = false, reportMissingImports = false, reportFunctionMemberAccess = false }, { root = "tests", reportGeneralTypeIssues = false, reportOptionalSubscript = false, reportOptionalMemberAccess = false, reportUntypedFunctionDecorator = false, reportUnknownArgumentType = false, reportUnknownMemberType = false, reportUnknownVariableType = false, reportUnnecessaryIsInstance = false, reportUnnecessaryComparison = false, reportArgumentType = false, reportCallIssue = false }, ] diff --git a/samples/instrument/agentforce/__init__.py b/samples/instrument/agentforce/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/samples/instrument/agentforce/main.py b/samples/instrument/agentforce/main.py new file mode 100644 index 00000000..20397589 --- /dev/null +++ b/samples/instrument/agentforce/main.py @@ -0,0 +1,114 @@ +"""Sample: smoke-test the LayerLens AgentForce adapter wiring. + +This sample is **smoke-only**. A real ``import_sessions`` call requires: + +1. A Salesforce Connected App with the JWT Bearer flow configured. +2. A private key (PEM) authorized for the Connected App. 
+3. A user with read access to the AIAgentSession DMOs. + +Most CI environments don't have those, so this sample only exercises the +adapter's local wiring (instantiate, attach sink, demonstrate the +configuration flow) and exits cleanly. If the required ``SALESFORCE_*`` +env vars are present, it will attempt one ``connect()`` call to verify auth. + +Required environment for the smoke run: + +* (none — the sample exits cleanly without any env vars) + +Optional environment for the auth check: + +* ``SALESFORCE_CLIENT_ID`` — Connected App consumer key. +* ``SALESFORCE_USERNAME`` — Salesforce user the JWT is issued for. +* ``SALESFORCE_PRIVATE_KEY`` — PEM-encoded private key. +* ``SALESFORCE_INSTANCE_URL`` — your org's My Domain URL + (e.g. ``https://example.my.salesforce.com``). +* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional). + +Run:: + + pip install 'layerlens[agentforce]' + python -m samples.instrument.agentforce.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.frameworks.agentforce import ( + AgentForceAdapter, + SalesforceAuthError, + SalesforceCredentials, +) + + +def _have_salesforce_env() -> bool: + return all( + os.environ.get(name) + for name in ( + "SALESFORCE_CLIENT_ID", + "SALESFORCE_USERNAME", + "SALESFORCE_PRIVATE_KEY", + ) + ) + + +def main() -> int: + sink = HttpEventSink( + adapter_name="salesforce_agentforce", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + if not _have_salesforce_env(): + print( + "SALESFORCE_* env vars are not set; running smoke check only.", + file=sys.stderr, + ) + # Smoke check: verify the adapter can be constructed and shut down + # without performing any network I/O. 
+ adapter = AgentForceAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + info = adapter.get_adapter_info() + print(f"Adapter: {info.name} v{info.version} (framework={info.framework})") + sink.close() + return 0 + + credentials = SalesforceCredentials( + client_id=os.environ["SALESFORCE_CLIENT_ID"], + username=os.environ["SALESFORCE_USERNAME"], + private_key=os.environ["SALESFORCE_PRIVATE_KEY"], + instance_url=os.environ.get( + "SALESFORCE_INSTANCE_URL", "https://login.salesforce.com" + ), + ) + + adapter = AgentForceAdapter( + credentials=credentials, + capture_config=CaptureConfig.standard(), + ) + adapter.add_sink(sink) + + try: + adapter.connect() + print("AgentForce adapter authenticated against Salesforce.") + # An import call would look like: + # result = adapter.import_sessions(start_date="2026-04-01", limit=10) + # print(f"Imported {result.events_generated} events.") + except SalesforceAuthError as exc: + print(f"Salesforce auth failed: {exc}", file=sys.stderr) + return 1 + finally: + adapter.disconnect() + sink.close() + + print("Telemetry shipped (smoke). Check the LayerLens dashboard.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/autogen/__init__.py b/samples/instrument/autogen/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/samples/instrument/autogen/main.py b/samples/instrument/autogen/main.py new file mode 100644 index 00000000..0a8ac9b4 --- /dev/null +++ b/samples/instrument/autogen/main.py @@ -0,0 +1,91 @@ +"""Sample: instrument an AutoGen one-turn conversation with LayerLens. + +Builds an ``AssistantAgent`` + ``UserProxyAgent``, connects them through the +``AutoGenAdapter``, and runs a single-turn ``initiate_chat`` exchange. Each +agent ``send`` / ``receive`` / ``generate_reply`` emits LayerLens events that +ship to atlas-app via ``HttpEventSink``. + +Required environment: + +* ``OPENAI_API_KEY`` — used by AutoGen's ``llm_config``. 
+* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional). + +Run:: + + pip install 'layerlens[autogen,providers-openai]' + python -m samples.instrument.autogen.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.frameworks.autogen import AutoGenAdapter + + +def main() -> int: + if not os.environ.get("OPENAI_API_KEY"): + print("OPENAI_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + from autogen import AssistantAgent, UserProxyAgent + except ImportError: + print( + "pyautogen not installed. Install with:\n" + " pip install 'layerlens[autogen,providers-openai]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="autogen", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = AutoGenAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + llm_config = { + "config_list": [ + {"model": "gpt-4o-mini", "api_key": os.environ["OPENAI_API_KEY"]}, + ], + "temperature": 0, + "timeout": 30, + } + + assistant = AssistantAgent( + name="assistant", + system_message="You are a concise assistant. Reply with one short sentence.", + llm_config=llm_config, + ) + user = UserProxyAgent( + name="user", + human_input_mode="NEVER", + max_consecutive_auto_reply=0, + code_execution_config=False, + is_termination_msg=lambda _msg: True, + ) + + try: + adapter.connect_agents(assistant, user) + user.initiate_chat(assistant, message="What is 2 + 2?") + last = assistant.last_message(user) + print(f"Response: {last.get('content') if last else '(none)'}") + finally: + adapter.disconnect() + sink.close() + + print("Telemetry shipped. 
Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/crewai/__init__.py b/samples/instrument/crewai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/samples/instrument/crewai/main.py b/samples/instrument/crewai/main.py new file mode 100644 index 00000000..0d12a9ff --- /dev/null +++ b/samples/instrument/crewai/main.py @@ -0,0 +1,84 @@ +"""Sample: instrument a CrewAI one-task crew with the LayerLens adapter. + +Builds a single-agent, single-task crew, wraps it with ``CrewAIAdapter``, and +runs ``kickoff()``. Each crew kickoff emits ``agent.input``, ``model.invoke``, +``tool.call`` (if any), and ``agent.output`` events that ship to atlas-app +via ``HttpEventSink``. + +Required environment: + +* ``OPENAI_API_KEY`` — used by the underlying CrewAI LLM (the default is + OpenAI; CrewAI honours the standard env var). +* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional). + +Run:: + + pip install 'layerlens[crewai,providers-openai]' + python -m samples.instrument.crewai.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.frameworks.crewai import CrewAIAdapter + + +def main() -> int: + if not os.environ.get("OPENAI_API_KEY"): + print("OPENAI_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + from crewai import Crew, Task, Agent + except ImportError: + print( + "crewai not installed. 
Install with:\n" + " pip install 'layerlens[crewai,providers-openai]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="crewai", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = CrewAIAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + researcher = Agent( + role="Math Tutor", + goal="Answer arithmetic questions concisely.", + backstory="A concise math tutor who replies with a single number.", + allow_delegation=False, + verbose=False, + ) + task = Task( + description="What is 2 + 2? Reply with just the number.", + agent=researcher, + expected_output="A single integer.", + ) + crew = Crew(agents=[researcher], tasks=[task], verbose=False) + + try: + instrumented = adapter.instrument_crew(crew) + result = instrumented.kickoff() + print(f"Result: {result}") + finally: + sink.close() + adapter.disconnect() + + print("Telemetry shipped. Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/langchain/__init__.py b/samples/instrument/langchain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/samples/instrument/langchain/main.py b/samples/instrument/langchain/main.py new file mode 100644 index 00000000..cb628f0c --- /dev/null +++ b/samples/instrument/langchain/main.py @@ -0,0 +1,83 @@ +"""Sample: instrument LangChain with the LayerLens callback handler. + +Runs a single LCEL chain (prompt | llm) with ``LayerLensCallbackHandler`` +installed on the chain. Every LLM/tool/chain callback fires a LayerLens +event that ships to atlas-app via ``HttpEventSink``. + +Required environment: + +* ``OPENAI_API_KEY`` — used by the underlying ``ChatOpenAI`` model. +* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional). 
+ +Run:: + + pip install 'layerlens[langchain,providers-openai]' + python -m samples.instrument.langchain.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.frameworks.langchain import LayerLensCallbackHandler + + +def main() -> int: + if not os.environ.get("OPENAI_API_KEY"): + print("OPENAI_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + from langchain_openai import ChatOpenAI + from langchain_core.prompts import ChatPromptTemplate + except ImportError: + print( + "langchain / langchain-openai not installed. Install with:\n" + " pip install 'layerlens[langchain,providers-openai]' langchain-openai", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="langchain", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + handler = LayerLensCallbackHandler(capture_config=CaptureConfig.standard()) + handler.add_sink(sink) + handler.connect() + + try: + llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=20, callbacks=[handler]) + prompt = ChatPromptTemplate.from_messages( + [ + ("system", "You are concise."), + ("user", "{question}"), + ], + ) + chain = prompt | llm + + result = chain.invoke( + {"question": "What is 2 + 2?"}, + config={"callbacks": [handler]}, + ) + + text = getattr(result, "content", str(result)) + print(f"Response: {text}") + print(f"Events captured: {len(handler.get_events())}") + finally: + sink.close() + handler.disconnect() + + print("Telemetry shipped. 
Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/langgraph/__init__.py b/samples/instrument/langgraph/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/samples/instrument/langgraph/main.py b/samples/instrument/langgraph/main.py new file mode 100644 index 00000000..5bf1a4c8 --- /dev/null +++ b/samples/instrument/langgraph/main.py @@ -0,0 +1,74 @@ +"""Sample: instrument a LangGraph state machine with LayerLens. + +Builds a tiny one-node ``StateGraph``, wraps it with +``LayerLensLangGraphAdapter.wrap_graph``, and invokes it. The adapter emits +``environment.config`` + ``agent.input`` + ``agent.output`` (and +``agent.state.change`` if the state changed) which ship to atlas-app via +``HttpEventSink``. + +The sample does not require an LLM provider — the node is a pure Python +function — so no API key is needed. This keeps the sample fast, free, and +network-independent so it can be used as a smoke test for the adapter +plumbing itself. + +Run:: + + pip install 'layerlens[langgraph]' + python -m samples.instrument.langgraph.main +""" + +from __future__ import annotations + +import sys +from typing import Any + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.frameworks.langgraph import LayerLensLangGraphAdapter + + +def main() -> int: + try: + from langgraph.graph import END, StateGraph + except ImportError: + print( + "langgraph not installed. 
Install with:\n" + " pip install 'layerlens[langgraph]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="langgraph", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = LayerLensLangGraphAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + def greet(state: dict[str, Any]) -> dict[str, Any]: + return {"messages": ["hi"], "count": state.get("count", 0) + 1} + + graph: StateGraph = StateGraph(dict) + graph.add_node("greet", greet) + graph.set_entry_point("greet") + graph.add_edge("greet", END) + compiled = graph.compile() + + try: + traced = adapter.wrap_graph(compiled) + result = traced.invoke({"count": 0}) + print(f"Result: {result}") + finally: + sink.close() + adapter.disconnect() + + print("Telemetry shipped. Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/layerlens/instrument/adapters/frameworks/__init__.py b/src/layerlens/instrument/adapters/frameworks/__init__.py new file mode 100644 index 00000000..4cfd328f --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/__init__.py @@ -0,0 +1,32 @@ +"""Framework adapters for the LayerLens Instrument layer. + +Each framework adapter wraps an agent / chain framework's lifecycle to +intercept agent runs, model invocations, tool calls, state changes, and +handoffs, emitting events through the LayerLens telemetry pipeline. 
+ +Adapters available (loaded on demand via :class:`AdapterRegistry`): + +* ``langchain`` — LangChain (callbacks + agent + chain + memory) +* ``langgraph`` — LangGraph (graph hooks + handoff detection + state) +* ``crewai`` — CrewAI (delegation + team metadata) +* ``autogen`` — AutoGen (group chat + lifecycle) +* ``agentforce`` — Salesforce Agentforce (auth, client, event mapping) +* ``semantic_kernel`` — Microsoft Semantic Kernel (filters + lifecycle) +* ``langfuse_importer`` — Langfuse trace import / export +* ``embedding`` — Embedding + vector store instrumentation +* ``openai_agents`` — OpenAI Agents SDK lifecycle +* ``ms_agent_framework`` — MS Agent Framework lifecycle +* ``agno`` — Agno lifecycle +* ``bedrock_agents`` — AWS Bedrock Agents lifecycle +* ``llama_index`` — LlamaIndex lifecycle +* ``google_adk`` — Google ADK lifecycle +* ``strands`` — Strands lifecycle +* ``benchmark_import`` — Benchmark replay-based ingestion +* ``pydantic_ai`` — Pydantic-AI lifecycle +* ``smolagents`` — SmolAgents (HuggingFace) lifecycle +* ``browser_use`` — Browser-Use lifecycle (placeholder; ported in M7) + +Importing this package does NOT import any framework SDK. 
+""" + +from __future__ import annotations diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py b/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py new file mode 100644 index 00000000..a2ad3043 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/__init__.py @@ -0,0 +1,64 @@ +""" +STRATIX Salesforce Agentforce Adapter + +Full-featured adapter for Salesforce Agentforce agent evaluation: +- Session trace import via Data Cloud SOQL (batch/incremental) +- Agent API REST client (real-time session capture) +- Platform Events subscriber (gRPC Pub/Sub for near-real-time) +- Einstein Trust Layer policy import +- LLM evaluation scenarios (completions, A/B testing, model comparison) + +DMO Objects (Data Cloud): +- AIAgentSession +- AIAgentSessionParticipant +- AIAgentInteraction +- AIAgentInteractionStep +- AIAgentInteractionMessage + +Package: pip install 'layerlens[agentforce]' +""" + +from __future__ import annotations + +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + NormalizationError, + SalesforceAuthError, + SalesforceConnection, + SalesforceQueryError, + SalesforceCredentials, +) +from layerlens.instrument.adapters.frameworks.agentforce.client import AgentApiClient +from layerlens.instrument.adapters.frameworks.agentforce.events import PlatformEventSubscriber +from layerlens.instrument.adapters.frameworks.agentforce.mapper import AgentApiMapper +from layerlens.instrument.adapters.frameworks.agentforce.adapter import AgentForceAdapter +from layerlens.instrument.adapters.frameworks.agentforce.importer import ImportResult, AgentForceImporter +from layerlens.instrument.adapters.frameworks.agentforce.llm_eval import EinsteinEvaluator +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer +from layerlens.instrument.adapters.frameworks.agentforce.trust_layer import TrustLayerImporter + +__all__ = [ + # Core adapter + "AgentForceAdapter", + # 
Auth + "SalesforceAuthError", + "SalesforceConnection", + "SalesforceCredentials", + "SalesforceQueryError", + "NormalizationError", + # Import + "AgentForceImporter", + "AgentForceNormalizer", + "ImportResult", + # Agent API + "AgentApiClient", + "AgentApiMapper", + # Trust Layer + "TrustLayerImporter", + # Platform Events + "PlatformEventSubscriber", + # Evaluation + "EinsteinEvaluator", +] + +# Registry lazy-loading convention +ADAPTER_CLASS = AgentForceAdapter diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py b/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py new file mode 100644 index 00000000..f2b2c016 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/adapter.py @@ -0,0 +1,188 @@ +""" +AgentForce Adapter + +BaseAdapter-compliant wrapper for AgentForce trace import. +Provides lifecycle management, circuit breaker protection, +CaptureConfig filtering, and health reporting. +""" + +from __future__ import annotations + +import uuid +import logging +from typing import Any + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + SalesforceAuthError, + SalesforceConnection, + SalesforceCredentials, +) +from layerlens.instrument.adapters.frameworks.agentforce.importer import ImportResult, AgentForceImporter +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer + +logger = logging.getLogger(__name__) + + +class AgentForceAdapter(BaseAdapter): + """ + BaseAdapter wrapper for AgentForce trace import. 
+ + Provides the standard STRATIX adapter lifecycle (connect/disconnect/health_check) + around the AgentForce importer, routing imported events through the BaseAdapter + circuit breaker and CaptureConfig pipeline. + + Usage: + adapter = AgentForceAdapter(stratix=stratix, credentials=credentials) + adapter.connect() + result = adapter.import_sessions(start_date="2026-02-21") + adapter.disconnect() + """ + + FRAMEWORK = "salesforce_agentforce" + VERSION = "0.1.0" + # ``frameworks/agentforce/models.py`` line 17 imports + # ``from pydantic import Field, BaseModel`` only — both names exist + # identically under v1 and v2. No v2-only decorators + # (field_validator/model_validator) appear anywhere in the + # agentforce subpackage. Salesforce Agentforce itself is a remote + # REST API, not a Python library, so there is no framework-side + # Pydantic dependency to constrain. + requires_pydantic = PydanticCompat.V1_OR_V2 + + def __init__( + self, + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + credentials: SalesforceCredentials | None = None, + connection: SalesforceConnection | None = None, + batch_size: int = 200, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + self._credentials = credentials + self._connection = connection + self._normalizer = AgentForceNormalizer() + self._importer: AgentForceImporter | None = None + self._batch_size = batch_size + + def connect(self) -> None: + """Authenticate with Salesforce and prepare the importer.""" + if self._connection is None: + if self._credentials is None: + raise SalesforceAuthError("Either 'credentials' or 'connection' must be provided") + self._connection = SalesforceConnection(credentials=self._credentials) + + if self._credentials and self._credentials.is_expired: + self._connection.authenticate() + + self._importer = AgentForceImporter( + connection=self._connection, + normalizer=self._normalizer, + batch_size=self._batch_size, + ) + + self._connected = 
True + self._status = AdapterStatus.HEALTHY + logger.info("AgentForce adapter connected") + + def disconnect(self) -> None: + """Disconnect and release resources.""" + self._importer = None + self._connected = False + self._status = AdapterStatus.DISCONNECTED + logger.info("AgentForce adapter disconnected") + + def health_check(self) -> AdapterHealth: + """Return adapter health, including Salesforce connection status.""" + message = None + if self._connection and self._credentials and self._credentials.is_expired: + message = "Salesforce token expired, will re-authenticate on next operation" + + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + adapter_version=self.VERSION, + message=message, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="AgentForceAdapter", + version=self.VERSION, + framework=self.FRAMEWORK, + capabilities=[ + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_TOOLS, + ], + description="LayerLens adapter for Salesforce AgentForce trace import", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + return ReplayableTrace( + adapter_name="AgentForceAdapter", + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + state_snapshots=[], + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + def import_sessions( + self, + start_date: str | None = None, + end_date: str | None = None, + agent_type: str | None = None, + channel_type: str | None = None, + limit: int | None = None, + last_import_timestamp: str | None = None, + ) -> ImportResult: + """ + Import AgentForce sessions and emit events through the adapter pipeline. + + Events are routed through ``emit_dict_event()`` for circuit breaker + and CaptureConfig protection. + + Returns: + ImportResult summary. 
+ """ + if not self._connected or not self._importer: + raise RuntimeError("Adapter not connected. Call connect() first.") + + events, result = self._importer.import_sessions( + start_date=start_date, + end_date=end_date, + agent_type=agent_type, + channel_type=channel_type, + limit=limit, + last_import_timestamp=last_import_timestamp, + ) + + # Route each event through BaseAdapter pipeline + emitted = 0 + for event in events: + event_type = event.get("event_type", "") + payload = event.get("payload", {}) + # Add identity and timestamp to payload for downstream consumers + if "identity" in event: + payload["_identity"] = event["identity"] + if "timestamp" in event: + payload["_timestamp"] = event["timestamp"] + + self.emit_dict_event(event_type, payload) + emitted += 1 + + result.events_generated = emitted + return result diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py b/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py new file mode 100644 index 00000000..1009c2fb --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/auth.py @@ -0,0 +1,328 @@ +""" +Salesforce OAuth 2.0 JWT Bearer Authentication + +Implements the JWT Bearer flow for server-to-server authentication +with Salesforce Data Cloud. Includes retry with exponential backoff, +timeouts, and credential masking. 
+""" + +from __future__ import annotations + +import os +import time +import logging +from typing import Any +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + +# Timeout defaults (seconds) +_AUTH_TIMEOUT = 30 +_QUERY_TIMEOUT = 60 + +# Retry defaults +_MAX_RETRIES = 3 +_RETRY_BASE_DELAY = 1.0 # seconds +_RETRY_MAX_DELAY = 30.0 # seconds + +# Salesforce access token lifetime (conservative; actual is ~2 hours) +_TOKEN_LIFETIME_S = 3600 + +# Rate limit warning threshold (fraction of the API limit consumed, 0.0-1.0) +_RATE_LIMIT_WARN_THRESHOLD = 0.8 + + +class SalesforceAuthError(Exception): + """Raised when Salesforce authentication fails.""" + + def __init__(self, message: str, status_code: int | None = None, endpoint: str = "") -> None: + self.status_code = status_code + self.endpoint = endpoint + super().__init__(message) + + +class SalesforceQueryError(Exception): + """Raised when a SOQL query fails.""" + + def __init__(self, message: str, status_code: int | None = None, soql: str = "") -> None: + self.status_code = status_code + self.soql = soql + super().__init__(message) + + +class NormalizationError(Exception): + """Raised when normalization of AgentForce records fails.""" + + pass + + +@dataclass +class SalesforceCredentials: + """Salesforce connection credentials.""" + + client_id: str + username: str + private_key: str # raw PEM, file path, or env var reference ("env:NAME" or "$NAME") + instance_url: str = "https://login.salesforce.com" + access_token: str | None = None + token_expiry: float = 0.0 + + @property + def is_expired(self) -> bool: + return time.time() >= self.token_expiry + + def resolve_private_key(self) -> str: + """Resolve the private key from env var, file path, or raw PEM string.""" + key = self.private_key + # Check env var reference + if key.startswith("$") or key.startswith("env:"): + env_name = key.lstrip("$").removeprefix("env:") + resolved = os.environ.get(env_name, "") + if not resolved: + raise SalesforceAuthError( + f"Environment variable 
'{env_name}' not set for private key" + ) + return resolved + # Check file path + if os.path.isfile(key): + with open(key) as f: + return f.read() + # Assume raw PEM + return key + + def __repr__(self) -> str: + return ( + f"SalesforceCredentials(" + f"client_id='{self.client_id[:8]}...', " + f"username='{self.username}', " + f"instance_url='{self.instance_url}', " + f"private_key='***REDACTED***', " + f"access_token={'***REDACTED***' if self.access_token else 'None'}, " + f"is_expired={self.is_expired})" + ) + + +@dataclass +class SalesforceConnection: + """Active Salesforce connection with retry and timeout support.""" + + credentials: SalesforceCredentials + instance_url: str = "" + api_version: str = "v60.0" + auth_timeout: int = _AUTH_TIMEOUT + query_timeout: int = _QUERY_TIMEOUT + max_retries: int = _MAX_RETRIES + + def authenticate(self) -> None: + """Authenticate using JWT Bearer flow with retry.""" + import jwt + import requests # type: ignore[import-untyped,unused-ignore] + + resolved_key = self.credentials.resolve_private_key() + + # Build JWT + now = int(time.time()) + payload = { + "iss": self.credentials.client_id, + "sub": self.credentials.username, + "aud": self.credentials.instance_url, + "exp": now + 300, + } + token = jwt.encode(payload, resolved_key, algorithm="RS256") + + endpoint = f"{self.credentials.instance_url}/services/oauth2/token" + last_error: Exception | None = None + + for attempt in range(self.max_retries): + try: + response = requests.post( + endpoint, + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": token, + }, + timeout=self.auth_timeout, + ) + response.raise_for_status() + data = response.json() + + self.credentials.access_token = data["access_token"] + self.instance_url = data["instance_url"] + self.credentials.token_expiry = now + _TOKEN_LIFETIME_S + logger.info("Authenticated with Salesforce: %s", self.instance_url) + return + except requests.exceptions.Timeout as e: + last_error = e + 
logger.warning( + "Salesforce auth timeout (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.HTTPError as e: + status = e.response.status_code if e.response is not None else None + # Don't retry 4xx (client errors) except 429 (rate limit) + if status is not None and 400 <= status < 500 and status != 429: + raise SalesforceAuthError( + f"Salesforce authentication failed (HTTP {status}). " + f"Check credentials and re-authenticate using `stratix agentforce connect`." + f" " + f"Endpoint: {endpoint}", + status_code=status, + endpoint=endpoint, + ) from e + last_error = e + logger.warning( + "Salesforce auth HTTP error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.RequestException as e: + last_error = e + logger.warning( + "Salesforce auth request error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + + # Exponential backoff + if attempt < self.max_retries - 1: + delay = min( + _RETRY_BASE_DELAY * (2**attempt), + _RETRY_MAX_DELAY, + ) + time.sleep(delay) + + raise SalesforceAuthError( + f"Salesforce authentication failed after {self.max_retries} attempts. " + f"Last error: {last_error}. " + f"Re-authenticate using `stratix agentforce connect`. " + f"Endpoint: {endpoint}", + endpoint=endpoint, + ) + + @staticmethod + def _check_rate_limit(response_headers: dict[str, Any]) -> None: + """Parse Sforce-Limit-Info header and warn if approaching limits. + + Salesforce returns ``Sforce-Limit-Info: api-usage=25/15000`` on every + API response. We log a warning when usage exceeds the configured + threshold so operators can react before hitting hard limits. 
+ """ + limit_info = response_headers.get("Sforce-Limit-Info", "") + if not limit_info: + return + try: + # Format: "api-usage=USED/LIMIT"; further entries may follow after ";" + usage_part = limit_info.split("=", 1)[1].split(";", 1)[0].strip() if "=" in limit_info else "" + if "/" in usage_part: + used_str, total_str = usage_part.split("/", 1) + used, total = int(used_str), int(total_str) + if total > 0 and used / total >= _RATE_LIMIT_WARN_THRESHOLD: + logger.warning( + "Salesforce API rate limit warning: %d/%d (%.0f%%) consumed", + used, + total, + (used / total) * 100, + ) + except (ValueError, IndexError): + # Malformed header — ignore silently + pass + + def query(self, soql: str) -> list[dict[str, Any]]: + """Execute a SOQL query with retry, timeout, and pagination.""" + if self.credentials.is_expired: + self.authenticate() + + import requests + + url = f"{self.instance_url}/services/data/{self.api_version}/query" + headers = { + "Authorization": f"Bearer {self.credentials.access_token}", + "Content-Type": "application/json", + } + + records: list[dict[str, Any]] = [] + params: dict[str, str] | None = {"q": soql} + + while True: + last_error: Exception | None = None + success = False + + for attempt in range(self.max_retries): + try: + response = requests.get( + url, + headers=headers, + params=params, + timeout=self.query_timeout, + ) + response.raise_for_status() + + # Check Salesforce API rate limits + self._check_rate_limit(response.headers) # type: ignore[arg-type] + + data = response.json() + + records.extend(data.get("records", [])) + + # Handle pagination + next_url = data.get("nextRecordsUrl") + if next_url: + url = f"{self.instance_url}{next_url}" + params = None # Pagination URL includes query params + success = True + break + + except requests.exceptions.Timeout as e: + last_error = e + logger.warning( + "Salesforce query timeout (attempt %d/%d)", + attempt + 1, + self.max_retries, + ) + except requests.exceptions.HTTPError as e: + status = e.response.status_code if e.response is not None else None + if 
status is not None and 400 <= status < 500 and status != 429: + raise SalesforceQueryError( + f"SOQL query failed (HTTP {status})", + status_code=status, + soql=soql[:200], + ) from e + last_error = e + logger.warning( + "Salesforce query HTTP error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + except requests.exceptions.RequestException as e: + last_error = e + logger.warning( + "Salesforce query request error (attempt %d/%d): %s", + attempt + 1, + self.max_retries, + e, + ) + + if attempt < self.max_retries - 1: + delay = min( + _RETRY_BASE_DELAY * (2**attempt), + _RETRY_MAX_DELAY, + ) + time.sleep(delay) + + if not success: + raise SalesforceQueryError( + f"SOQL query failed after {self.max_retries} attempts. " + f"Last error: {last_error}", + soql=soql[:200], + ) + + # If no next page, we're done + if not data.get("nextRecordsUrl"): + break + + return records diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/client.py b/src/layerlens/instrument/adapters/frameworks/agentforce/client.py new file mode 100644 index 00000000..b2ab2b53 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/client.py @@ -0,0 +1,334 @@ +""" +Salesforce Agent API REST Client + +Provides a typed client for the Salesforce Agent API: +- Session creation and lifecycle management +- Synchronous and streaming message exchange +- Response parsing with action and guardrail extraction + +Reference: https://developer.salesforce.com/docs/ai/agentforce/guide/agent-api.html +""" + +from __future__ import annotations + +import json +import time +import logging +from typing import Any +from collections.abc import Generator + +from layerlens.instrument.adapters.frameworks.agentforce.auth import ( + SalesforceConnection, + SalesforceQueryError, +) +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + AgentApiMessage, + AgentApiSession, +) + +logger = logging.getLogger(__name__) + +# Agent API path prefix 
+_AGENT_API_PREFIX = "/services/data/{version}/agent" + +# Default timeout for Agent API calls (seconds) +_API_TIMEOUT = 30 + +# Maximum response text length to capture (prevent memory bloat) +_MAX_RESPONSE_LENGTH = 50_000 + + +class AgentApiClient: + """ + REST client for the Salesforce Agent API. + + Wraps session creation, message exchange, and response parsing. + All methods use the authenticated ``SalesforceConnection`` for + token management and retry logic. + + Usage: + client = AgentApiClient(connection=connection) + session = client.create_session(agent_name="Service_Agent") + response = client.send_message(session.session_id, "How do I reset my password?") + client.end_session(session.session_id) + """ + + def __init__( + self, + connection: SalesforceConnection, + api_timeout: int = _API_TIMEOUT, + ) -> None: + self._connection = connection + self._api_timeout = api_timeout + self._base_url = "" + + @property + def base_url(self) -> str: + """Build the Agent API base URL from the connection.""" + if not self._base_url: + instance = self._connection.instance_url + version = self._connection.api_version + self._base_url = f"{instance}{_AGENT_API_PREFIX.format(version=version)}" + return self._base_url + + def create_session( + self, + agent_name: str, + context: dict[str, Any] | None = None, + ) -> AgentApiSession: + """ + Create a new Agent API session. + + Args: + agent_name: Name of the Agentforce agent to connect to. + context: Optional context variables for the session. + + Returns: + AgentApiSession with the session ID and initial state. + + Raises: + SalesforceQueryError: If the API call fails. 
+ """ + import requests # type: ignore[import-untyped,unused-ignore] + + if not agent_name or not agent_name.strip(): + raise ValueError("agent_name must be a non-empty string") + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + body: dict[str, Any] = {"agentName": agent_name} + if context: + body["context"] = context + + try: + response = requests.post( + url, + headers=headers, + json=body, + timeout=self._api_timeout, + ) + response.raise_for_status() + data = response.json() + + return AgentApiSession( + session_id=data.get("sessionId", ""), + agent_name=agent_name, + status="active", + created_at=data.get("createdAt"), + ) + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to create Agent API session: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def send_message( + self, + session_id: str, + message: str, + stream: bool = False, + ) -> AgentApiMessage | Generator[str, None, None]: + """ + Send a message to an active Agent API session. + + Args: + session_id: The session ID from ``create_session()``. + message: User message text to send. + stream: If True, return a generator of streaming response chunks. + + Returns: + AgentApiMessage with the agent response, or a generator if streaming. + + Raises: + SalesforceQueryError: If the API call fails. 
+ """ + if not session_id or not session_id.strip(): + raise ValueError("session_id must be a non-empty string") + if not message or not message.strip(): + raise ValueError("message must be a non-empty string") + + import requests + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions/{session_id}/messages" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + if stream: + headers["Accept"] = "text/event-stream" + + body = {"message": {"text": message}} + + try: + response = requests.post( + url, + headers=headers, + json=body, + timeout=self._api_timeout, + stream=stream, + ) + response.raise_for_status() + + if stream: + return self._stream_response(response) + + return self._parse_message_response(response.json()) + + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to send Agent API message: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def end_session(self, session_id: str) -> None: + """ + End an active Agent API session. + + Args: + session_id: The session ID to end. + + Raises: + SalesforceQueryError: If the API call fails. 
+ """ + if not session_id or not session_id.strip(): + raise ValueError("session_id must be a non-empty string") + + import requests + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + url = f"{self.base_url}/sessions/{session_id}" + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + + try: + response = requests.delete( + url, + headers=headers, + timeout=self._api_timeout, + ) + response.raise_for_status() + except requests.exceptions.RequestException as e: + raise SalesforceQueryError( + f"Failed to end Agent API session: {e}", + status_code=getattr(getattr(e, "response", None), "status_code", None), + ) from e + + def capture_session( + self, + agent_name: str, + messages: list[str], + context: dict[str, Any] | None = None, + ) -> AgentApiSession: + """ + Convenience method: create session, send all messages, end session. + + Returns an ``AgentApiSession`` with all messages and responses. + + Args: + agent_name: Agentforce agent name. + messages: List of user messages to send sequentially. + context: Optional session context. + + Returns: + Complete AgentApiSession with all exchanged messages. 
+ """ + session = self.create_session(agent_name, context) + all_messages: list[AgentApiMessage] = [] + + for msg_text in messages: + # Record user message + all_messages.append(AgentApiMessage(role="user", content=msg_text)) + + # Send and capture response + response = self.send_message(session.session_id, msg_text) + if isinstance(response, AgentApiMessage): + all_messages.append(response) + + self.end_session(session.session_id) + + session.messages = all_messages + session.status = "ended" + session.ended_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + return session + + # --- Internal helpers --- + + @staticmethod + def _parse_message_response(data: dict[str, Any]) -> AgentApiMessage: + """Parse a synchronous Agent API message response.""" + messages = data.get("messages", []) + if not messages: + return AgentApiMessage( + role="agent", + content=data.get("text", ""), + timestamp=data.get("timestamp"), + ) + + # Take the last agent message + last = messages[-1] + actions = [] + guardrails = [] + + # Extract actions if present + for action in data.get("actions", []): + actions.append( + { + "name": action.get("name", "unknown"), + "parameters": action.get("parameters", {}), + "result": action.get("result"), + } + ) + + # Extract guardrail results if present + for gr in data.get("guardrailResults", []): + guardrails.append( + { + "name": gr.get("name", "unknown"), + "triggered": gr.get("triggered", False), + "message": gr.get("message"), + } + ) + + return AgentApiMessage( + id=last.get("id"), + role="agent", + content=str(last.get("text", ""))[:_MAX_RESPONSE_LENGTH], + timestamp=last.get("timestamp"), + topic=data.get("topic"), + actions=actions, + guardrail_results=guardrails, + ) + + @staticmethod + def _stream_response(response: Any) -> Generator[str, None, None]: + """Parse a streaming Agent API response (SSE format).""" + try: + for line in response.iter_lines(decode_unicode=True): + if not line: + continue + if line.startswith("data: "): + 
data_str = line[6:] + if data_str.strip() == "[DONE]": + return + try: + chunk = json.loads(data_str) + text = chunk.get("text", "") + if text: + yield text + except json.JSONDecodeError: + logger.debug("Failed to parse SSE chunk: %s", data_str[:100]) + finally: + response.close() diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/events.py b/src/layerlens/instrument/adapters/frameworks/agentforce/events.py new file mode 100644 index 00000000..c17cf9c9 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/events.py @@ -0,0 +1,330 @@ +""" +Salesforce Platform Events Subscriber + +Subscribes to Salesforce Platform Events via the gRPC Pub/Sub API +for near-real-time Agentforce session capture. + +Supports: +- gRPC Pub/Sub API subscription (AgentSession__e) +- Automatic reconnection with exponential backoff +- Event replay from a specific replay ID +- Graceful shutdown with pending event flush + +Reference: https://developer.salesforce.com/docs/platform/pub-sub-api/overview +""" + +from __future__ import annotations + +import time +import logging +import threading +from typing import Any +from collections.abc import Callable + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.models import AgentSessionEvent + +logger = logging.getLogger(__name__) + +# Default Platform Event channel +_DEFAULT_CHANNEL = "/event/AgentSession__e" + +# Reconnection backoff constants +_RECONNECT_BASE_DELAY = 1.0 +_RECONNECT_MAX_DELAY = 60.0 +_MAX_RECONNECT_ATTEMPTS = 10 + +# Subscriber batch size +_BATCH_SIZE = 100 + + +class PlatformEventSubscriber: + """ + Subscribe to Salesforce Platform Events for real-time Agentforce capture. + + Uses the Salesforce gRPC Pub/Sub API to receive events as they occur, + with automatic reconnection and replay support. 
+ + Usage: + subscriber = PlatformEventSubscriber( + connection=connection, + on_event=handle_event, + ) + subscriber.start() + # ... later ... + subscriber.stop() + """ + + def __init__( + self, + connection: SalesforceConnection, + on_event: Callable[[AgentSessionEvent], None] | None = None, + channel: str = _DEFAULT_CHANNEL, + replay_id: str | None = None, + ) -> None: + """ + Initialize the Platform Events subscriber. + + Args: + connection: Authenticated Salesforce connection. + on_event: Callback invoked for each received event. + channel: Platform Event channel to subscribe to. + replay_id: Optional replay ID to resume from. + """ + self._connection = connection + self._on_event = on_event + self._channel = channel + self._replay_id = replay_id + self._running = False + self._thread: threading.Thread | None = None + self._reconnect_attempts = 0 + self._events_received = 0 + self._last_replay_id: str | None = replay_id + + @property + def is_running(self) -> bool: + """Whether the subscriber is actively listening.""" + return self._running + + @property + def events_received(self) -> int: + """Total events received since start.""" + return self._events_received + + @property + def last_replay_id(self) -> str | None: + """Last processed replay ID (for resume on restart).""" + return self._last_replay_id + + def start(self) -> None: + """ + Start the Platform Events subscriber in a background thread. + + The subscriber will attempt to connect and begin receiving events. + On connection failure, it retries with exponential backoff. + """ + if self._running: + logger.warning("Platform Events subscriber already running") + return + + self._running = True + self._thread = threading.Thread( + target=self._subscribe_loop, + name="stratix-sf-events", + daemon=True, + ) + self._thread.start() + logger.info( + "Platform Events subscriber started on channel: %s", + self._channel, + ) + + def stop(self) -> None: + """ + Stop the Platform Events subscriber. 
+ + Signals the background thread to stop and waits for graceful shutdown. + """ + self._running = False + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=5.0) + self._thread = None + logger.info( + "Platform Events subscriber stopped. Events received: %d", + self._events_received, + ) + + def _subscribe_loop(self) -> None: + """Main subscription loop with reconnection logic.""" + while self._running: + try: + self._subscribe() + except Exception as e: + # ``self._running`` can flip concurrently from ``stop()`` — + # mypy can't see the cross-thread mutation, so it thinks the + # break is unreachable inside ``while self._running:``. It's + # not. + if not self._running: + break # type: ignore[unreachable] + self._reconnect_attempts += 1 + if self._reconnect_attempts > _MAX_RECONNECT_ATTEMPTS: + logger.error( + "Platform Events subscriber exceeded max reconnect attempts (%d). Stopping.", # noqa: E501 + _MAX_RECONNECT_ATTEMPTS, + ) + self._running = False + break + + delay = min( + _RECONNECT_BASE_DELAY * (2 ** (self._reconnect_attempts - 1)), + _RECONNECT_MAX_DELAY, + ) + logger.warning( + "Platform Events connection lost (attempt %d/%d): %s. Retrying in %.1fs.", + self._reconnect_attempts, + _MAX_RECONNECT_ATTEMPTS, + str(e)[:200], + delay, + ) + time.sleep(delay) + + def _subscribe(self) -> None: + """ + Subscribe to the Platform Event channel. + + This method uses HTTP long-polling as a fallback when the gRPC + Pub/Sub API client is not available. For production use with + high-volume events, the gRPC client is recommended. + """ + # Attempt gRPC Pub/Sub API first + try: + self._subscribe_grpc() + return + except ImportError: + logger.info("gRPC Pub/Sub client not available. Falling back to CometD polling.") + except Exception as e: + logger.warning("gRPC subscription failed: %s. 
Falling back.", e) + + # Fallback: CometD / HTTP long-polling + self._subscribe_cometd() + + def _subscribe_grpc(self) -> None: + """ + Subscribe using the Salesforce gRPC Pub/Sub API. + + Requires the ``grpcio`` and ``avro`` packages. + """ + # Import gRPC dependencies (optional) + import grpc # type: ignore[import-untyped,unused-ignore] # noqa: F401 + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + # gRPC Pub/Sub API endpoint + pubsub_endpoint = self._connection.instance_url.replace("https://", "") + ":443" + + logger.info("Connecting to gRPC Pub/Sub API: %s", pubsub_endpoint) + + # NOTE: Full gRPC stub implementation requires the Salesforce + # pub-sub proto definitions. This is a structural placeholder + # that demonstrates the connection pattern. Production code + # should use the salesforce-pubsub package. + raise NotImplementedError( + "Full gRPC Pub/Sub implementation requires salesforce-pubsub package. " + "Install: pip install salesforce-pubsub" + ) + + def _subscribe_cometd(self) -> None: + """ + Subscribe using CometD long-polling (fallback). + + Uses the Streaming API (/cometd) endpoint for Platform Events. + Lower throughput than gRPC but works without additional dependencies. 
+ """ + import requests # type: ignore[import-untyped,unused-ignore] + + if self._connection.credentials.is_expired: + self._connection.authenticate() + + base_url = self._connection.instance_url + api_version = self._connection.api_version + cometd_url = f"{base_url}/cometd/{api_version.lstrip('v')}" + + headers = { + "Authorization": f"Bearer {self._connection.credentials.access_token}", + "Content-Type": "application/json", + } + + # CometD handshake + handshake_payload = [ + { + "channel": "/meta/handshake", + "version": "1.0", + "supportedConnectionTypes": ["long-polling"], + "minimumVersion": "1.0", + } + ] + + try: + resp = requests.post( + cometd_url, + headers=headers, + json=handshake_payload, + timeout=30, + ) + resp.raise_for_status() + handshake_data = resp.json() + client_id = handshake_data[0].get("clientId") + if not client_id: + raise RuntimeError("CometD handshake failed: no clientId") + + # Subscribe to channel + subscribe_payload = [ + { + "channel": "/meta/subscribe", + "clientId": client_id, + "subscription": self._channel, + } + ] + if self._replay_id: + subscribe_payload[0]["ext"] = { + "replay": {self._channel: self._replay_id}, + } + + resp = requests.post( + cometd_url, + headers=headers, + json=subscribe_payload, + timeout=30, + ) + resp.raise_for_status() + + # Reset reconnect attempts on successful connection + self._reconnect_attempts = 0 + + # Long-polling loop + while self._running: + connect_payload = [ + { + "channel": "/meta/connect", + "clientId": client_id, + "connectionType": "long-polling", + } + ] + resp = requests.post( + cometd_url, + headers=headers, + json=connect_payload, + timeout=120, + ) + resp.raise_for_status() + + for msg in resp.json(): + channel = msg.get("channel", "") + if channel == self._channel: + self._handle_event(msg.get("data", {})) + + except requests.exceptions.RequestException as e: + raise RuntimeError(f"CometD connection error: {e}") from e + + def _handle_event(self, data: dict[str, Any]) -> None: 
+ """Process a received Platform Event.""" + try: + event = AgentSessionEvent( + session_id=data.get("SessionId__c", ""), + agent_name=data.get("AgentName__c"), + topic_name=data.get("TopicName__c"), + actions_taken=data.get("ActionsTaken__c"), + response_text=data.get("ResponseText__c"), + trust_layer_flags=data.get("TrustLayerFlags__c"), + replay_id=str(data.get("event", {}).get("replayId", "")), + ) + + self._events_received += 1 + self._last_replay_id = event.replay_id + + if self._on_event: + self._on_event(event) + + except Exception as e: + logger.warning("Failed to process Platform Event: %s", e) diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py new file mode 100644 index 00000000..8438e854 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/importer.py @@ -0,0 +1,268 @@ +""" +AgentForce Trace Importer + +Imports AgentForce Session Tracing data from Salesforce Data Cloud +and normalizes it to STRATIX canonical events. 
+ +Supports: +- Batch import (date range filter) +- Incremental import (timestamp-based) +- Session, participant, interaction, step, and message extraction +""" + +from __future__ import annotations + +import re +import logging +from typing import Any +from dataclasses import field, dataclass + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.normalizer import AgentForceNormalizer + +# Regex for validating ISO 8601 date strings (YYYY-MM-DD) +_DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") +# Regex for validating ISO 8601 timestamp strings. Anchored at the end so that +# trailing text cannot be smuggled into the SOQL WHERE clause. +_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:?\d{2})?\Z") +# Regex for Salesforce record IDs (exactly 15 or 18 char alphanumeric) +_SF_ID_RE = re.compile(r"^[a-zA-Z0-9]{15}(?:[a-zA-Z0-9]{3})?$") + +logger = logging.getLogger(__name__) + + +@dataclass +class ImportResult: + """Result of an AgentForce import operation.""" + + sessions_imported: int = 0 + participants_imported: int = 0 + interactions_imported: int = 0 + steps_imported: int = 0 + messages_imported: int = 0 + events_generated: int = 0 + errors: list[str] = field(default_factory=list) + + @property + def total_records(self) -> int: + return ( + self.sessions_imported + + self.participants_imported + + self.interactions_imported + + self.steps_imported + + self.messages_imported + ) + + +class AgentForceImporter: + """ + Import AgentForce traces from Salesforce Data Cloud.
+ + Usage: + connection = SalesforceConnection(credentials) + connection.authenticate() + importer = AgentForceImporter(connection) + events, result = importer.import_sessions( + start_date="2026-02-21", + end_date="2026-02-28", + ) + """ + + def __init__( + self, + connection: SalesforceConnection, + normalizer: AgentForceNormalizer | None = None, + batch_size: int = 200, + ) -> None: + self._connection = connection + self._normalizer = normalizer or AgentForceNormalizer() + self._batch_size = batch_size + + def import_sessions( + self, + start_date: str | None = None, + end_date: str | None = None, + agent_type: str | None = None, + channel_type: str | None = None, + limit: int | None = None, + last_import_timestamp: str | None = None, + ) -> tuple[list[dict[str, Any]], ImportResult]: + """ + Import AgentForce sessions and all related records. + + Args: + start_date: Import sessions starting from this date (ISO 8601) + end_date: Import sessions up to this date (ISO 8601) + agent_type: Filter by agent type (Employee, EinsteinSDR, EinsteinServiceAgent) + channel_type: Filter by channel type + limit: Maximum sessions to import + last_import_timestamp: For incremental sync, only import after this timestamp + + Returns: + Tuple of (list of STRATIX events, ImportResult summary) + """ + result = ImportResult() + all_events: list[dict[str, Any]] = [] + + # Build session query with validated parameters + conditions = [] + if start_date: + self._validate_date(start_date) + conditions.append(f"StartTimestamp >= {start_date}T00:00:00Z") + if end_date: + self._validate_date(end_date) + conditions.append(f"StartTimestamp <= {end_date}T23:59:59Z") + if last_import_timestamp: + self._validate_timestamp(last_import_timestamp) + conditions.append(f"StartTimestamp > {last_import_timestamp}") + + where = f" WHERE {' AND '.join(conditions)}" if conditions else "" + limit_clause = f" LIMIT {limit}" if limit else f" LIMIT {self._batch_size}" + + soql = ( + "SELECT Id, StartTimestamp, 
EndTimestamp, AiAgentChannelTypeId, " + "AiAgentSessionEndType, VoiceCallId, MessagingSessionId, PreviousSessionId " + f"FROM AIAgentSession{where} ORDER BY StartTimestamp ASC{limit_clause}" + ) + + try: + sessions = self._connection.query(soql) + except Exception as e: + result.errors.append(f"Session query failed: {e}") + return all_events, result + + if not sessions: + return all_events, result + + session_ids = [s["Id"] for s in sessions] + result.sessions_imported = len(sessions) + + # Normalize sessions + for session in sessions: + events = self._normalizer.normalize_session(session) + all_events.extend(events) + + # Import participants + participants = self._query_related( + "AIAgentSessionParticipant", + "AiAgentSessionId", + session_ids, + "Id, AiAgentSessionId, AiAgentTypeId, AiAgentApiName, " + "AiAgentVersionApiName, ParticipantId, AiAgentSessionParticipantRoleId", + result=result, + ) + result.participants_imported = len(participants) + for p in participants: + all_events.append(self._normalizer.normalize_participant(p)) + + # Import interactions + interactions = self._query_related( + "AIAgentInteraction", + "AiAgentSessionId", + session_ids, + "Id, AiAgentSessionId, AiAgentInteractionTypeId, " + "TelemetryTraceId, TelemetryTraceSpanId, TopicApiName, " + "AttributeText, PrevInteractionId", + order_by="Id ASC", + result=result, + ) + result.interactions_imported = len(interactions) + for i in interactions: + all_events.append(self._normalizer.normalize_interaction(i)) + + if interactions: + interaction_ids = [i["Id"] for i in interactions] + + # Import steps + steps = self._query_related( + "AIAgentInteractionStep", + "AiAgentInteractionId", + interaction_ids, + "Id, AiAgentInteractionId, AiAgentInteractionStepTypeId, " + "InputValueText, OutputValueText, ErrorMessageText, " + "GenerationId, GenAiGatewayRequestId, GenAiGatewayResponseId, " + "Name, TelemetryTraceSpanId", + order_by="Id ASC", + result=result, + ) + result.steps_imported = len(steps) + 
for s in steps: + all_events.append(self._normalizer.normalize_step(s)) + + # Import messages + messages = self._query_related( + "AIAgentInteractionMessage", + "AiAgentInteractionId", + interaction_ids, + "Id, AiAgentInteractionId, AiAgentInteractionMessageTypeId, " + "ContentText, AiAgentInteractionMsgContentTypeId, " + "MessageSentTimestamp, ParentMessageId", + order_by="MessageSentTimestamp ASC", + result=result, + ) + result.messages_imported = len(messages) + for m in messages: + all_events.append(self._normalizer.normalize_message(m)) + + result.events_generated = len(all_events) + logger.info( + "AgentForce import complete: %d sessions, %d events generated", + result.sessions_imported, + result.events_generated, + ) + return all_events, result + + def _query_related( + self, + object_name: str, + foreign_key: str, + parent_ids: list[str], + fields: str, + order_by: str | None = None, + result: ImportResult | None = None, + ) -> list[dict[str, Any]]: + """Query related records in batches to respect SOQL limits.""" + all_records: list[dict[str, Any]] = [] + + # Batch parent IDs to avoid SOQL IN clause limits + for i in range(0, len(parent_ids), self._batch_size): + batch = parent_ids[i : i + self._batch_size] + # Validate IDs to prevent SOQL injection + safe_ids = [self._validate_sf_id(pid) for pid in batch] + ids_str = "', '".join(safe_ids) + soql = f"SELECT {fields} FROM {object_name} WHERE {foreign_key} IN ('{ids_str}')" + if order_by: + soql += f" ORDER BY {order_by}" + + try: + records = self._connection.query(soql) + all_records.extend(records) + except Exception as e: + error_msg = f"Failed to query {object_name}: {e}" + logger.error(error_msg) + if result is not None: + result.errors.append(error_msg) + + return all_records + + @staticmethod + def _validate_date(value: str) -> None: + """Validate an ISO 8601 date string (YYYY-MM-DD).""" + if not _DATE_RE.match(value): + raise ValueError(f"Invalid date format: '{value}'. 
Expected YYYY-MM-DD.") + + @staticmethod + def _validate_timestamp(value: str) -> None: + """Validate an ISO 8601 timestamp string.""" + if not _TIMESTAMP_RE.match(value): + raise ValueError(f"Invalid timestamp format: '{value}'. Expected ISO 8601.") + + @staticmethod + def _validate_sf_id(value: str) -> str: + """Validate a Salesforce ID format (15 or 18 char alphanumeric). + + Raises: + ValueError: If the value does not match the Salesforce ID format. + """ + if not _SF_ID_RE.match(value): + raise ValueError(f"Invalid Salesforce ID format: {value!r}") + return value diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py b/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py new file mode 100644 index 00000000..b033a29c --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/llm_eval.py @@ -0,0 +1,437 @@ +""" +Agentforce LLM Evaluation Scenarios + +Provides evaluation capabilities beyond agent tracing: +- Einstein completions evaluation (grading LLM responses) +- Prompt template A/B testing for Agentforce topics +- Model comparison (GPT vs Claude vs Gemini via Atlas) +- CRM outcome ground truth correlation + +These scenarios use imported Agentforce session data as input +and run Stratix graders to produce evaluation scores. 
+""" + +from __future__ import annotations + +import os +import logging +from typing import Any +from dataclasses import field, dataclass + +from layerlens.instrument.adapters.frameworks.agentforce.models import EvaluationResult + +logger = logging.getLogger(__name__) + + +def _get_stratix_client() -> Any | None: + """Lazily create a Stratix API client from environment variables.""" + api_url = os.environ.get("LAYERLENS_API_URL") + api_key = os.environ.get("LAYERLENS_API_KEY") + if not api_url or not api_key: + return None + try: + from layerlens import Stratix as StratixClient + + return StratixClient(base_url=api_url, api_key=api_key) + except Exception as exc: + logger.debug("Could not create Stratix client: %s", exc) + return None + + +# Default graders for Agentforce evaluation +_DEFAULT_GRADERS = ["relevance", "faithfulness", "coherence", "safety"] + +# Composite score weights (aligned with Section 4.3 of integration doc) +_DEFAULT_WEIGHTS = { + "topic_accuracy": 0.20, + "action_correctness": 0.25, + "response_quality": 0.20, + "safety_compliance": 0.20, + "crm_outcome": 0.15, +} + + +@dataclass +class ABTestResult: + """Result of an A/B test between two prompt variants.""" + + variant_a_scores: dict[str, float] = field(default_factory=dict) + variant_b_scores: dict[str, float] = field(default_factory=dict) + winner: str = "" + significance: float = 0.0 + sample_size: int = 0 + + +@dataclass +class ModelComparisonResult: + """Result of comparing multiple models on the same test cases.""" + + model_scores: dict[str, dict[str, float]] = field(default_factory=dict) + best_model: str = "" + test_cases_evaluated: int = 0 + + +class EinsteinEvaluator: + """ + Evaluate Agentforce LLM responses using Stratix graders. + + Operates on imported session data (from ``AgentForceAdapter.import_sessions()``) + and applies graders to LLM execution steps, action sequences, and agent responses. 
+ + Usage: + evaluator = EinsteinEvaluator(adapter=adapter) + results = evaluator.evaluate_completions( + session_ids=["0Xx..."], + graders=["relevance", "faithfulness"], + ) + """ + + def __init__( + self, + adapter: Any = None, + connection: Any = None, + ) -> None: + """ + Initialize the evaluator. + + Args: + adapter: AgentForceAdapter instance (for session import). + connection: SalesforceConnection (for ground truth queries). + """ + self._adapter = adapter + self._connection = connection + self._client = _get_stratix_client() + + def evaluate_completions( + self, + session_ids: list[str], + graders: list[str] | None = None, + ) -> list[EvaluationResult]: + """ + Evaluate LLM completions from imported Agentforce sessions. + + Extracts LLM execution steps from the session data and runs + the specified graders on each completion. + + Args: + session_ids: Salesforce session IDs to evaluate. + graders: List of grader names (defaults to relevance, faithfulness, coherence, and safety). + + Returns: + List of EvaluationResult, one per session. + """ + if not session_ids: + return [] + + grader_names = graders or _DEFAULT_GRADERS + results: list[EvaluationResult] = [] + + for session_id in session_ids: + try: + scores = self._evaluate_session(session_id, grader_names) + composite = self._compute_composite_score(scores) + results.append( + EvaluationResult( + session_id=session_id, + scores=scores, + composite_score=composite, + ) + ) + except Exception as e: + logger.warning("Failed to evaluate session %s: %s", session_id, e) + results.append( + EvaluationResult( + session_id=session_id, + errors=[str(e)], + ) + ) + + return results + + def evaluate_topic( + self, + topic: str, + graders: list[str] | None = None, + limit: int = 100, + ) -> list[EvaluationResult]: + """ + Convenience method: import sessions for a topic and evaluate. + + Combines session import + grading in one call. + + Args: + topic: Agentforce topic name to evaluate. + graders: Grader names to run.
+ limit: Maximum sessions to evaluate. + + Returns: + List of EvaluationResult for the topic. + """ + if not self._adapter: + raise RuntimeError("Adapter required for evaluate_topic()") + + # Import up to ``limit`` sessions. Note: ``topic`` is not passed to the + # importer, so no topic filter is applied at this step. + events, result = self._adapter._importer.import_sessions(limit=limit) + + # Extract session IDs from imported events + session_ids: list[str] = [] + for event in events: + payload = event.get("payload", {}) + sid = payload.get("session_id") + if sid and sid not in session_ids: + session_ids.append(sid) + + return self.evaluate_completions(session_ids[:limit], graders) + + def ab_test_prompts( + self, + topic: str, + variant_a: str, + variant_b: str, + test_cases: list[dict[str, str]] | None = None, + graders: list[str] | None = None, + ) -> ABTestResult: + """ + A/B test two prompt variants for an Agentforce topic. + + Args: + topic: The Agentforce topic being tested. + variant_a: First prompt instruction text. + variant_b: Second prompt instruction text. + test_cases: List of test inputs (dicts with "input" key). + graders: Grader names to use for scoring. + + Returns: + ABTestResult with per-variant scores and winner.
+ """ + grader_names = graders or ["relevance", "trajectory_accuracy"] + cases = test_cases or [] + sample_size = len(cases) + + # Score each variant + a_scores = self._score_variant(variant_a, cases, grader_names) + b_scores = self._score_variant(variant_b, cases, grader_names) + + # Determine winner by average score across graders + a_avg = sum(a_scores.values()) / max(len(a_scores), 1) + b_avg = sum(b_scores.values()) / max(len(b_scores), 1) + winner = "variant_a" if a_avg >= b_avg else "variant_b" + + return ABTestResult( + variant_a_scores=a_scores, + variant_b_scores=b_scores, + winner=winner, + significance=abs(a_avg - b_avg), + sample_size=sample_size, + ) + + def compare_models( + self, + topic: str, + models: list[str], + test_cases: list[dict[str, str]] | None = None, + graders: list[str] | None = None, + ) -> ModelComparisonResult: + """ + Compare multiple LLM models for an Agentforce topic. + + Args: + topic: The Agentforce topic to evaluate. + models: List of model names (e.g., ["gpt-5.3", "claude-opus-4-6"]). + test_cases: Test inputs for evaluation. + graders: Grader names to use. + + Returns: + ModelComparisonResult with per-model scores and best model. 
+ """ + grader_names = graders or _DEFAULT_GRADERS + cases = test_cases or [] + model_scores: dict[str, dict[str, float]] = {} + + for model in models: + scores = self._score_model(model, topic, cases, grader_names) + model_scores[model] = scores + + # Determine best model by highest average score + best_model = "" + best_avg = -1.0 + for model, scores in model_scores.items(): + avg = sum(scores.values()) / max(len(scores), 1) + if avg > best_avg: + best_avg = avg + best_model = model + + return ModelComparisonResult( + model_scores=model_scores, + best_model=best_model, + test_cases_evaluated=len(cases), + ) + + def correlate_outcomes( + self, + session_ids: list[str], + outcome_query: str, + evaluation_dimensions: list[str] | None = None, + ) -> list[EvaluationResult]: + """ + Correlate evaluation scores with CRM business outcomes. + + Args: + session_ids: Session IDs to evaluate and correlate. + outcome_query: SOQL query to fetch business outcomes. + evaluation_dimensions: Grader dimensions to include. + + Returns: + EvaluationResult list with ground_truth populated. + """ + dimensions = evaluation_dimensions or _DEFAULT_GRADERS + + # Evaluate sessions + results = self.evaluate_completions(session_ids, dimensions) + + # Fetch ground truth from Salesforce + if self._connection: + try: + outcomes = self._connection.query(outcome_query) + outcome_map = {r.get("CaseId", r.get("Id", "")): r for r in outcomes} + for result in results: + gt = outcome_map.get(result.session_id, {}) + if gt: + result.ground_truth = gt + except Exception as e: + logger.warning("Failed to fetch ground truth: %s", e) + + return results + + # --- Internal helpers --- + + def _evaluate_session( + self, + session_id: str, + grader_names: list[str], + ) -> dict[str, float]: + """Run graders on a single session. 
Returns grader->score mapping.""" + if self._client: + try: + result = self._client.evaluations.create( + trace_id=session_id, + grader_ids=grader_names, + ) + # result may be a dict or model; normalise to dict + result_dict = result if isinstance(result, dict) else result.model_dump() + return {g: result_dict.get("scores", {}).get(g, 0.0) for g in grader_names} + except Exception as exc: + logger.warning( + "Grader invocation failed for session %s: %s", + session_id, + exc, + ) + + logger.warning( + "No Stratix client configured — returning 0.0 for session %s. " + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables.", + session_id, + ) + return dict.fromkeys(grader_names, 0.0) + + def _compute_composite_score( + self, + scores: dict[str, float], + ) -> float | None: + """Compute a weighted composite score from individual grader scores.""" + if not scores: + return None + + total_weight = 0.0 + weighted_sum = 0.0 + + # Map grader names to weight categories + grader_to_category = { + "topic_accuracy": "topic_accuracy", + "tool_correctness": "action_correctness", + "tool_adherence": "action_correctness", + "relevance": "response_quality", + "faithfulness": "response_quality", + "coherence": "response_quality", + "safety": "safety_compliance", + "hallucination": "safety_compliance", + "pii_detection": "safety_compliance", + } + + for grader, score in scores.items(): + category = grader_to_category.get(grader, "response_quality") + weight = _DEFAULT_WEIGHTS.get(category, 0.1) + weighted_sum += score * weight + total_weight += weight + + return weighted_sum / total_weight if total_weight > 0 else None + + def _score_variant( + self, + prompt: str, + test_cases: list[dict[str, str]], + grader_names: list[str], + ) -> dict[str, float]: + """Score a prompt variant across test cases.""" + if not test_cases: + logger.warning("No test cases provided for variant scoring — returning 0.0.") + return dict.fromkeys(grader_names, 0.0) + + if self._client: + try: + 
aggregated: dict[str, float] = dict.fromkeys(grader_names, 0.0) + for case in test_cases: + result = self._client.evaluations.create( + trace_id=case.get("trace_id", ""), + grader_ids=grader_names, + config={"prompt_override": prompt}, + ) + result_dict = result if isinstance(result, dict) else result.model_dump() + for g in grader_names: + aggregated[g] += result_dict.get("scores", {}).get(g, 0.0) + n = len(test_cases) + return {g: aggregated[g] / n for g in grader_names} + except Exception as exc: + logger.warning("Variant scoring failed: %s", exc) + + logger.warning( + "No Stratix client configured — returning 0.0 for variant scoring. " + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables." + ) + return dict.fromkeys(grader_names, 0.0) + + def _score_model( + self, + model: str, + topic: str, + test_cases: list[dict[str, str]], + grader_names: list[str], + ) -> dict[str, float]: + """Score a model on test cases.""" + if not test_cases: + logger.warning("No test cases provided for model %s — returning 0.0.", model) + return dict.fromkeys(grader_names, 0.0) + + if self._client: + try: + aggregated: dict[str, float] = dict.fromkeys(grader_names, 0.0) + for case in test_cases: + result = self._client.evaluations.create( + trace_id=case.get("trace_id", ""), + grader_ids=grader_names, + config={"model_override": model}, + ) + result_dict = result if isinstance(result, dict) else result.model_dump() + for g in grader_names: + aggregated[g] += result_dict.get("scores", {}).get(g, 0.0) + n = len(test_cases) + return {g: aggregated[g] / n for g in grader_names} + except Exception as exc: + logger.warning("Model %s scoring failed: %s", model, exc) + + logger.warning( + "No Stratix client configured — returning 0.0 for model %s. 
" + "Set LAYERLENS_API_URL and LAYERLENS_API_KEY environment variables.", + model, + ) + return dict.fromkeys(grader_names, 0.0) diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py b/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py new file mode 100644 index 00000000..fa555775 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/mapper.py @@ -0,0 +1,251 @@ +""" +Agent API Session to Stratix Trace Event Mapper + +Maps Agent API session data (from ``client.py``) to Stratix canonical +event types. This is distinct from ``normalizer.py`` which handles +Data Cloud DMO records from SOQL queries. + +Mapping: +- Session creation -> agent.state.change (trace_start) +- User message -> agent.input (L1) +- Agent response -> agent.output (L1) +- Topic classification-> environment.config (L4a) +- Action invocation -> tool.call (L5a) +- Guardrail check -> policy.violation (Cross) +- Escalation -> agent.handoff (Cross) +- Session end -> agent.state.change (trace_end) +""" + +from __future__ import annotations + +import time +import logging +from typing import Any + +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + AgentApiMessage, + AgentApiSession, +) + +logger = logging.getLogger(__name__) + + +class AgentApiMapper: + """ + Maps Agent API sessions to Stratix trace events. + + Each public method returns a list of event dicts compatible with + ``BaseAdapter.emit_dict_event(event_type, payload)``. + """ + + def map_session(self, session: AgentApiSession) -> list[dict[str, Any]]: + """ + Map a complete Agent API session to a sequence of Stratix events. + + Args: + session: Complete AgentApiSession with messages. + + Returns: + Ordered list of ``{event_type, payload}`` dicts. 
+ """ + events: list[dict[str, Any]] = [] + + # Session start + events.append(self.map_session_start(session)) + + # Process each message + seen_topics: set[str] = set() + for msg in session.messages: + if msg.role == "user": + events.append(self.map_user_message(msg, session.session_id)) + elif msg.role == "agent": + events.append(self.map_agent_response(msg, session.session_id)) + + # Topic classification (emit once per topic) + if msg.topic and msg.topic not in seen_topics: + events.append( + self.map_topic_classification( + msg.topic, + session.agent_name or "unknown", + session.session_id, + ) + ) + seen_topics.add(msg.topic) + + # Action invocations + for action in msg.actions: + events.append( + self.map_action_invocation( + action, + session.session_id, + ) + ) + + # Guardrail checks + for gr in msg.guardrail_results: + events.append( + self.map_guardrail_check( + gr, + session.session_id, + ) + ) + + # Session end + events.append(self.map_session_end(session)) + + return events + + def map_session_start(self, session: AgentApiSession) -> dict[str, Any]: + """Map session creation to agent.state.change (trace_start).""" + return { + "event_type": "agent.state.change", + "payload": { + "framework": "salesforce_agentforce", + "event_subtype": "trace_start", + "session_id": session.session_id, + "agent_name": session.agent_name, + "timestamp_ns": _ts_to_ns(session.created_at), + }, + } + + def map_session_end(self, session: AgentApiSession) -> dict[str, Any]: + """Map session end to agent.state.change (trace_end).""" + start_ns = _ts_to_ns(session.created_at) + end_ns = _ts_to_ns(session.ended_at) + duration_ns = end_ns - start_ns if session.created_at and session.ended_at else 0 + + return { + "event_type": "agent.state.change", + "payload": { + "framework": "salesforce_agentforce", + "event_subtype": "trace_end", + "session_id": session.session_id, + "agent_name": session.agent_name, + "duration_ns": duration_ns, + "message_count": len(session.messages), + }, + } + + 
@staticmethod + def map_user_message( + msg: AgentApiMessage, + session_id: str, + ) -> dict[str, Any]: + """Map a user message to agent.input (L1).""" + return { + "event_type": "agent.input", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "content": { + "role": "human", + "message": msg.content, + }, + "timestamp_ns": _ts_to_ns(msg.timestamp), + }, + } + + @staticmethod + def map_agent_response( + msg: AgentApiMessage, + session_id: str, + ) -> dict[str, Any]: + """Map an agent response to agent.output (L1).""" + return { + "event_type": "agent.output", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "content": { + "role": "agent", + "message": msg.content, + }, + "timestamp_ns": _ts_to_ns(msg.timestamp), + }, + } + + @staticmethod + def map_topic_classification( + topic: str, + agent_name: str, + session_id: str, + ) -> dict[str, Any]: + """Map topic classification to environment.config (L4a).""" + return { + "event_type": "environment.config", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "agent_name": agent_name, + "topic": topic, + "config_type": "topic_classification", + }, + } + + @staticmethod + def map_action_invocation( + action: dict[str, Any], + session_id: str, + ) -> dict[str, Any]: + """Map an Agentforce action to tool.call (L5a).""" + return { + "event_type": "tool.call", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + "tool_name": action.get("name", "unknown"), + "tool_input": action.get("parameters", {}), + "tool_output": action.get("result"), + "tool_type": "salesforce_action", + }, + } + + @staticmethod + def map_guardrail_check( + guardrail: dict[str, Any], + session_id: str, + ) -> dict[str, Any]: + """Map a guardrail check to policy.violation (Cross-cutting).""" + return { + "event_type": "policy.violation", + "payload": { + "framework": "salesforce_agentforce", + "session_id": session_id, + 
"guardrail_name": guardrail.get("name", "unknown"), + "triggered": guardrail.get("triggered", False), + "message": guardrail.get("message"), + "source": "einstein_trust_layer", + }, + } + + @staticmethod + def map_escalation( + session_id: str, + from_agent: str, + to_agent: str = "human", + reason: str = "escalation", + ) -> dict[str, Any]: + """Map an escalation to agent.handoff (Cross-cutting).""" + return { + "event_type": "agent.handoff", + "payload": { + "from_agent": from_agent, + "to_agent": to_agent, + "reason": reason, + "framework": "salesforce_agentforce", + "session_id": session_id, + }, + } + + +def _ts_to_ns(ts: str | None) -> int: + """Convert an ISO 8601 timestamp string to nanoseconds since epoch.""" + if not ts: + return time.time_ns() + try: + from datetime import datetime + + dt = datetime.fromisoformat(ts.replace("Z", "+00:00")) + return int(dt.timestamp() * 1_000_000_000) + except (ValueError, TypeError): + return time.time_ns() diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/models.py b/src/layerlens/instrument/adapters/frameworks/agentforce/models.py new file mode 100644 index 00000000..dab4205c --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/models.py @@ -0,0 +1,322 @@ +""" +Pydantic models for Salesforce Agentforce data structures. + +Provides type-safe representations of: +- Salesforce DMO records (AIAgentSession, AIAgentInteractionStep, etc.) +- Agent API request/response payloads +- Platform Event payloads +- Trust Layer configuration +- LLM evaluation inputs/outputs +""" + +from __future__ import annotations + +from enum import Enum # Python 3.11+ has StrEnum; using `(str, Enum)` for 3.9/3.10 compat. 
+from typing import Any, Optional + +from pydantic import Field, BaseModel + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class AgentChannelType(str, Enum): + """Agentforce session channel types.""" + + WEB = "Web" + MESSAGING = "Messaging" + VOICE = "Voice" + SLACK = "Slack" + API = "Api" + + +class SessionEndType(str, Enum): + """How an Agentforce session ended.""" + + COMPLETED = "Completed" + ESCALATED = "Escalated" + TIMED_OUT = "TimedOut" + ERROR = "Error" + ABANDONED = "Abandoned" + + +class StepType(str, Enum): + """Agentforce interaction step types from DMO.""" + + USER_INPUT = "UserInputStep" + LLM_EXECUTION = "LLMExecutionStep" + FUNCTION = "FunctionStep" + ACTION_INVOCATION = "ActionInvocationStep" + + +class ParticipantType(str, Enum): + """Participant roles in an Agentforce session.""" + + AI = "ai" + HUMAN = "human" + + +class AuthFlow(str, Enum): + """Supported Salesforce authentication flows.""" + + JWT_BEARER = "jwt_bearer" + CLIENT_CREDENTIALS = "client_credentials" + NAMED_CREDENTIAL = "named_credential" + + +class CaptureMode(str, Enum): + """Agentforce capture modes.""" + + POLLING = "polling" + REALTIME = "realtime" + HYBRID = "hybrid" + + +# --------------------------------------------------------------------------- +# DMO Record Models +# --------------------------------------------------------------------------- + + +class AgentSession(BaseModel): + """AIAgentSession DMO record.""" + + id: str = Field(alias="Id", description="Salesforce record ID") + start_timestamp: Optional[str] = Field( + default=None, + alias="StartTimestamp", + description="ISO 8601 session start time", + ) + end_timestamp: Optional[str] = Field( + default=None, + alias="EndTimestamp", + description="ISO 8601 session end time", + ) + channel_type: Optional[str] = Field( + default=None, + alias="AiAgentChannelTypeId", + description="Session 
channel (Web, Messaging, Voice, etc.)", + ) + session_end_type: Optional[str] = Field( + default=None, + alias="AiAgentSessionEndType", + description="How the session ended", + ) + voice_call_id: Optional[str] = Field(default=None, alias="VoiceCallId") + messaging_session_id: Optional[str] = Field(default=None, alias="MessagingSessionId") + previous_session_id: Optional[str] = Field(default=None, alias="PreviousSessionId") + + model_config = {"populate_by_name": True} + + +class AgentParticipant(BaseModel): + """AIAgentSessionParticipant DMO record.""" + + id: str = Field(alias="Id") + session_id: str = Field(alias="AiAgentSessionId") + agent_type: Optional[str] = Field(default=None, alias="AiAgentTypeId") + agent_api_name: Optional[str] = Field(default=None, alias="AiAgentApiName") + agent_version: Optional[str] = Field(default=None, alias="AiAgentVersionApiName") + participant_id: Optional[str] = Field(default=None, alias="ParticipantId") + role: Optional[str] = Field(default=None, alias="AiAgentSessionParticipantRoleId") + + model_config = {"populate_by_name": True} + + +class AgentInteraction(BaseModel): + """AIAgentInteraction DMO record.""" + + id: str = Field(alias="Id") + session_id: str = Field(alias="AiAgentSessionId") + interaction_type: Optional[str] = Field(default=None, alias="AiAgentInteractionTypeId") + telemetry_trace_id: Optional[str] = Field(default=None, alias="TelemetryTraceId") + telemetry_span_id: Optional[str] = Field(default=None, alias="TelemetryTraceSpanId") + topic_api_name: Optional[str] = Field(default=None, alias="TopicApiName") + attribute_text: Optional[str] = Field(default=None, alias="AttributeText") + prev_interaction_id: Optional[str] = Field(default=None, alias="PrevInteractionId") + + model_config = {"populate_by_name": True} + + +class AgentInteractionStep(BaseModel): + """AIAgentInteractionStep DMO record.""" + + id: str = Field(alias="Id") + interaction_id: str = Field(alias="AiAgentInteractionId") + step_type: 
Optional[str] = Field(default=None, alias="AiAgentInteractionStepTypeId") + input_value: Optional[str] = Field(default=None, alias="InputValueText") + output_value: Optional[str] = Field(default=None, alias="OutputValueText") + error_message: Optional[str] = Field(default=None, alias="ErrorMessageText") + generation_id: Optional[str] = Field(default=None, alias="GenerationId") + gateway_request_id: Optional[str] = Field(default=None, alias="GenAiGatewayRequestId") + gateway_response_id: Optional[str] = Field(default=None, alias="GenAiGatewayResponseId") + name: Optional[str] = Field(default=None, alias="Name") + telemetry_span_id: Optional[str] = Field(default=None, alias="TelemetryTraceSpanId") + start_timestamp: Optional[str] = Field(default=None, alias="StartTimestamp") + end_timestamp: Optional[str] = Field(default=None, alias="EndTimestamp") + + model_config = {"populate_by_name": True} + + +class AgentInteractionMessage(BaseModel): + """AIAgentInteractionMessage DMO record.""" + + id: str = Field(alias="Id") + interaction_id: str = Field(alias="AiAgentInteractionId") + message_type: Optional[str] = Field(default=None, alias="AiAgentInteractionMessageTypeId") + content_text: Optional[str] = Field(default=None, alias="ContentText") + content_type: Optional[str] = Field(default=None, alias="AiAgentInteractionMsgContentTypeId") + sent_timestamp: Optional[str] = Field(default=None, alias="MessageSentTimestamp") + parent_message_id: Optional[str] = Field(default=None, alias="ParentMessageId") + + model_config = {"populate_by_name": True} + + +# --------------------------------------------------------------------------- +# Agent API Models +# --------------------------------------------------------------------------- + + +class AgentApiMessage(BaseModel): + """A message in an Agent API session.""" + + id: Optional[str] = Field(default=None, description="Message ID") + role: str = Field(description="Message role (user, agent, system)") + content: str = 
Field(description="Message content text") + timestamp: Optional[str] = Field(default=None, description="ISO 8601 timestamp") + topic: Optional[str] = Field(default=None, description="Classified topic name") + actions: list[dict[str, Any]] = Field( + default_factory=list, + description="Actions taken by the agent", + ) + guardrail_results: list[dict[str, Any]] = Field( + default_factory=list, + description="Trust Layer guardrail check results", + ) + + +class AgentApiSession(BaseModel): + """Represents an Agent API session.""" + + session_id: str = Field(description="Salesforce session ID") + agent_name: Optional[str] = Field(default=None, description="Agentforce agent name") + status: str = Field(default="active", description="Session status") + messages: list[AgentApiMessage] = Field( + default_factory=list, + description="Session messages in order", + ) + created_at: Optional[str] = Field(default=None, description="Session creation timestamp") + ended_at: Optional[str] = Field(default=None, description="Session end timestamp") + + +# --------------------------------------------------------------------------- +# Trust Layer Models +# --------------------------------------------------------------------------- + + +class TrustLayerGuardrail(BaseModel): + """Einstein Trust Layer guardrail configuration.""" + + name: str = Field(description="Guardrail name") + type: str = Field(description="Guardrail type (toxicity, pii, custom)") + enabled: bool = Field(default=True, description="Whether the guardrail is active") + threshold: Optional[float] = Field( + default=None, + description="Detection threshold (0.0-1.0)", + ) + action: str = Field( + default="block", + description="Action on violation (block, warn, log)", + ) + + +class TrustLayerConfig(BaseModel): + """Complete Einstein Trust Layer configuration.""" + + guardrails: list[TrustLayerGuardrail] = Field( + default_factory=list, + description="Configured guardrails", + ) + data_masking_enabled: bool = Field( + 
default=False, + description="Whether PII/PCI masking is active", + ) + zero_data_retention: bool = Field( + default=True, + description="Whether zero data retention is enabled for LLM calls", + ) + audit_trail_enabled: bool = Field( + default=True, + description="Whether audit trail logging is active", + ) + + +# --------------------------------------------------------------------------- +# Platform Event Models +# --------------------------------------------------------------------------- + + +class AgentSessionEvent(BaseModel): + """Platform Event payload for AgentSession__e.""" + + session_id: str = Field(description="Agentforce session ID") + agent_name: Optional[str] = Field(default=None, description="Agent name") + topic_name: Optional[str] = Field(default=None, description="Classified topic") + actions_taken: Optional[str] = Field( + default=None, + description="JSON-encoded actions list", + ) + response_text: Optional[str] = Field( + default=None, + description="Agent response text", + ) + trust_layer_flags: Optional[str] = Field( + default=None, + description="JSON-encoded Trust Layer results", + ) + replay_id: Optional[str] = Field( + default=None, + description="Platform Event replay ID for redelivery", + ) + + +# --------------------------------------------------------------------------- +# Evaluation Models +# --------------------------------------------------------------------------- + + +class EvaluationRequest(BaseModel): + """Request to evaluate Agentforce sessions.""" + + session_ids: list[str] = Field(description="Salesforce session IDs to evaluate") + graders: list[str] = Field( + default_factory=lambda: ["relevance", "faithfulness"], + description="Grader names to run", + ) + include_ground_truth: bool = Field( + default=False, + description="Whether to fetch CRM outcome ground truth", + ) + ground_truth_query: Optional[str] = Field( + default=None, + description="SOQL query for ground truth data", + ) + + +class EvaluationResult(BaseModel): + 
"""Result of evaluating Agentforce sessions.""" + + session_id: str = Field(description="Evaluated session ID") + scores: dict[str, float] = Field( + default_factory=dict, + description="Grader name -> score mapping", + ) + composite_score: Optional[float] = Field( + default=None, + description="Weighted composite quality score", + ) + ground_truth: dict[str, Any] = Field( + default_factory=dict, + description="CRM outcome data if fetched", + ) + errors: list[str] = Field(default_factory=list, description="Evaluation errors") diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py new file mode 100644 index 00000000..ada9bf61 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/normalizer.py @@ -0,0 +1,251 @@ +""" +AgentForce DMO to STRATIX Event Normalizer + +Maps AgentForce Data Model Objects to STRATIX canonical event types: +- AIAgentSession → agent.lifecycle (start/end) +- AIAgentSessionParticipant → agent.identity +- AIAgentInteraction → agent.input / agent.output +- AIAgentInteractionStep (UserInputStep) → agent.input (L1) +- AIAgentInteractionStep (LLMExecutionStep) → model.invoke (L3) +- AIAgentInteractionStep (FunctionStep / ActionInvocationStep) → tool.call (L5) +- AIAgentInteractionMessage (Input) → agent.input +- AIAgentInteractionMessage (Output) → agent.output +""" + +from __future__ import annotations + +import json +import logging +from typing import Any +from datetime import datetime + +logger = logging.getLogger(__name__) + +# Step type to STRATIX event type mapping +_STEP_TYPE_MAP = { + "UserInputStep": "agent.input", + "LLMExecutionStep": "model.invoke", + "FunctionStep": "tool.call", + "ActionInvocationStep": "tool.call", +} + + +class AgentForceNormalizer: + """Normalize AgentForce DMO records to STRATIX events.""" + + def normalize_session( + self, + session: dict[str, Any], + ) -> list[dict[str, Any]]: + """Normalize an 
AIAgentSession to agent.lifecycle start/end events.""" + events = [] + + sf_meta = { + "sf.session.id": session.get("Id"), + "sf.session.channel": session.get("AiAgentChannelTypeId"), + "sf.session.end_type": session.get("AiAgentSessionEndType"), + } + + # Start event + events.append( + { + "event_type": "agent.lifecycle", + "payload": { + "lifecycle_action": "start", + "session_id": session.get("Id"), + "channel_type": session.get("AiAgentChannelTypeId"), + "previous_session_id": session.get("PreviousSessionId"), + "voice_call_id": session.get("VoiceCallId"), + "messaging_session_id": session.get("MessagingSessionId"), + }, + "metadata": sf_meta, + "timestamp": session.get("StartTimestamp"), + } + ) + + # End event (if session has ended) + end_ts = session.get("EndTimestamp") + if end_ts: + events.append( + { + "event_type": "agent.lifecycle", + "payload": { + "lifecycle_action": "end", + "session_id": session.get("Id"), + "session_end_type": session.get("AiAgentSessionEndType"), + "channel_type": session.get("AiAgentChannelTypeId"), + }, + "metadata": sf_meta, + "timestamp": end_ts, + } + ) + + return events + + def normalize_participant( + self, + participant: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentSessionParticipant to agent identity metadata.""" + agent_type = participant.get("AiAgentTypeId", "") + is_human = agent_type == "Employee" + + return { + "event_type": "agent.identity", + "payload": { + "participant_type": "human" if is_human else "ai", + "agent_type": agent_type, + "agent_api_name": participant.get("AiAgentApiName"), + "agent_version": participant.get("AiAgentVersionApiName"), + "participant_id": participant.get("ParticipantId"), + "role": participant.get("AiAgentSessionParticipantRoleId"), + "session_id": participant.get("AiAgentSessionId"), + }, + } + + def normalize_interaction( + self, + interaction: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteraction to a trace span.""" + # Parse AttributeText as 
JSON if present + attr_text = interaction.get("AttributeText") + attributes = {} + if attr_text: + try: + attributes = json.loads(attr_text) + except (json.JSONDecodeError, TypeError): + attributes = {"raw": attr_text} + + return { + "event_type": "agent.interaction", + "identity": { + "trace_id": interaction.get("TelemetryTraceId"), + "span_id": interaction.get("TelemetryTraceSpanId"), + }, + "payload": { + "interaction_id": interaction.get("Id"), + "interaction_type": interaction.get("AiAgentInteractionTypeId"), + "topic": interaction.get("TopicApiName"), + "attributes": attributes, + "prev_interaction_id": interaction.get("PrevInteractionId"), + "session_id": interaction.get("AiAgentSessionId"), + }, + "metadata": { + "sf.topic.name": interaction.get("TopicApiName"), + "sf.session.id": interaction.get("AiAgentSessionId"), + }, + } + + def normalize_step( + self, + step: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteractionStep to the appropriate STRATIX event.""" + step_type = step.get("AiAgentInteractionStepTypeId", "") + event_type = _STEP_TYPE_MAP.get(step_type, "tool.call") + + base: dict[str, Any] = { + "event_type": event_type, + "identity": { + "span_id": step.get("TelemetryTraceSpanId"), + }, + } + + # Salesforce metadata passthrough + base["metadata"] = { + "sf.step.name": step.get("Name"), + "sf.step.id": step.get("Id"), + "sf.generation.id": step.get("GenerationId"), + } + + # Extract timing if available + start_ts = step.get("StartTimestamp") + end_ts = step.get("EndTimestamp") + if start_ts: + base["timestamp"] = start_ts + if start_ts and end_ts: + try: + start_dt = datetime.fromisoformat(str(start_ts).replace("Z", "+00:00")) + end_dt = datetime.fromisoformat(str(end_ts).replace("Z", "+00:00")) + base["duration_ms"] = (end_dt - start_dt).total_seconds() * 1000 + except (ValueError, TypeError): + pass + + if event_type == "model.invoke": + base["payload"] = { + "model": { + "provider": "salesforce", + "name": step.get("Name", 
"unknown"), + "version": "unavailable", + "parameters": {}, + }, + "input_messages": [{"role": "user", "content": step.get("InputValueText", "")}], + "output_message": {"role": "assistant", "content": step.get("OutputValueText", "")}, + "error": step.get("ErrorMessageText"), + "metadata": { + "generation_id": step.get("GenerationId"), + "gateway_request_id": step.get("GenAiGatewayRequestId"), + "gateway_response_id": step.get("GenAiGatewayResponseId"), + }, + } + + elif event_type == "tool.call": + input_text = step.get("InputValueText", "") + output_text = step.get("OutputValueText") + + base["payload"] = { + "tool": { + "name": step.get("Name", "unknown"), + "version": "unavailable", + "integration": "salesforce_agentforce", + }, + "input": _try_parse_json(input_text), + "output": _try_parse_json(output_text) if output_text else None, + "error": step.get("ErrorMessageText"), + } + + else: # agent.input + base["payload"] = { + "content": { + "role": "human", + "message": step.get("InputValueText", ""), + }, + } + + return base + + def normalize_message( + self, + message: dict[str, Any], + ) -> dict[str, Any]: + """Normalize an AIAgentInteractionMessage to agent.input or agent.output.""" + msg_type = message.get("AiAgentInteractionMessageTypeId", "") + event_type = "agent.output" if msg_type == "Output" else "agent.input" + role = "agent" if msg_type == "Output" else "human" + + return { + "event_type": event_type, + "payload": { + "content": { + "role": role, + "message": message.get("ContentText", ""), + "metadata": { + "content_type": message.get("AiAgentInteractionMsgContentTypeId"), + "parent_message_id": message.get("ParentMessageId"), + }, + }, + }, + "timestamp": message.get("MessageSentTimestamp"), + } + + +def _try_parse_json(text: str) -> dict[str, Any]: + """Try to parse text as JSON, falling back to raw string wrapper.""" + if not text: + return {} + try: + result = json.loads(text) + return result if isinstance(result, dict) else {"raw": text} + 
except (json.JSONDecodeError, TypeError): + return {"raw": text} diff --git a/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py b/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py new file mode 100644 index 00000000..d432d4a9 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/agentforce/trust_layer.py @@ -0,0 +1,194 @@ +""" +Einstein Trust Layer Policy Importer + +Imports Einstein Trust Layer guardrail configuration from Salesforce +and converts it to Stratix policy-as-code YAML format. + +Supports: +- Guardrail rules extraction via Metadata API +- Data masking policy import +- Conversion to Stratix YAML policy DSL +- Round-trip: import, evaluate, export updated policies + +Reference: https://developer.salesforce.com/docs/einstein/genai/guide/trust.html +""" + +from __future__ import annotations + +import logging + +from layerlens.instrument.adapters.frameworks.agentforce.auth import SalesforceConnection +from layerlens.instrument.adapters.frameworks.agentforce.models import ( + TrustLayerConfig, + TrustLayerGuardrail, +) + +logger = logging.getLogger(__name__) + +# Metadata API types for Trust Layer config +_TRUST_LAYER_METADATA_TYPES = [ + "GenAiPlugin", + "GenAiFunction", +] + +# Default guardrail names in Einstein Trust Layer +_DEFAULT_GUARDRAILS = [ + "toxicity_detection", + "pii_detection", + "prompt_injection", + "hallucination_detection", +] + + +class TrustLayerImporter: + """ + Import Einstein Trust Layer configuration and convert to Stratix policy. + + Usage: + importer = TrustLayerImporter(connection=connection) + config = importer.fetch_config() + yaml_str = importer.to_stratix_policy(config) + """ + + def __init__(self, connection: SalesforceConnection) -> None: + self._connection = connection + + def fetch_config(self) -> TrustLayerConfig: + """ + Fetch Einstein Trust Layer configuration from Salesforce. 
+ + Queries the Setup metadata to extract guardrail rules, + data masking settings, and audit configuration. + + Returns: + TrustLayerConfig with all detected guardrails and settings. + """ + guardrails: list[TrustLayerGuardrail] = [] + + # Query for GenAI guardrail metadata + try: + records = self._connection.query( + "SELECT DeveloperName, Language, MasterLabel " + "FROM GenAiPlugin " + "WHERE IsDeleted = false " + "ORDER BY DeveloperName ASC" + ) + for record in records: + name = record.get("DeveloperName", "") + if name: + guardrails.append( + TrustLayerGuardrail( + name=name, + type=self._classify_guardrail(name), + enabled=True, + ) + ) + except Exception as e: + logger.warning("Failed to query GenAiPlugin metadata: %s", e) + + # Add default guardrails if none found (Trust Layer has built-in ones) + if not guardrails: + for name in _DEFAULT_GUARDRAILS: + guardrails.append( + TrustLayerGuardrail( + name=name, + type=self._classify_guardrail(name), + enabled=True, + ) + ) + + return TrustLayerConfig( + guardrails=guardrails, + data_masking_enabled=False, # Disabled for agents per SF docs + zero_data_retention=True, + audit_trail_enabled=True, + ) + + def to_stratix_policy( + self, + config: TrustLayerConfig, + policy_name: str = "agentforce_trust_layer", + policy_version: str = "1.0.0", + ) -> str: + """ + Convert a TrustLayerConfig to Stratix policy-as-code YAML. + + Args: + config: The Trust Layer configuration to convert. + policy_name: Name for the generated policy. + policy_version: Version string for the policy. + + Returns: + YAML string representing a Stratix policy document. 
+ """ + rules: list[str] = [] + + for guardrail in config.guardrails: + action = "block" if guardrail.action == "block" else "warn" + threshold = guardrail.threshold if guardrail.threshold is not None else 0.8 + + rule_yaml = ( + f" - name: {guardrail.name}\n" + f' description: "Imported from Einstein Trust Layer: {guardrail.type}"\n' + f" type: {guardrail.type}\n" + f" enabled: {str(guardrail.enabled).lower()}\n" + f" threshold: {threshold}\n" + f" action: {action}\n" + f" source: einstein_trust_layer" + ) + rules.append(rule_yaml) + + rules_block = "\n".join(rules) if rules else " []" + + yaml_output = ( + f"# Stratix Policy - Imported from Einstein Trust Layer\n" + f"# Generated by: stratix.sdk.python.adapters.agentforce.trust_layer\n" + f"# Source: Salesforce Einstein Trust Layer\n" + f"\n" + f"policy:\n" + f" name: {policy_name}\n" + f' version: "{policy_version}"\n' + f' description: "Policy imported from Salesforce Einstein Trust Layer"\n' + f" source: salesforce_agentforce\n" + f"\n" + f"settings:\n" + f" data_masking: {str(config.data_masking_enabled).lower()}\n" + f" zero_data_retention: {str(config.zero_data_retention).lower()}\n" + f" audit_trail: {str(config.audit_trail_enabled).lower()}\n" + f"\n" + f"rules:\n" + f"{rules_block}\n" + ) + + return yaml_output + + def import_and_convert( + self, + policy_name: str = "agentforce_trust_layer", + ) -> tuple[TrustLayerConfig, str]: + """ + Convenience method: fetch config and convert to YAML in one call. + + Args: + policy_name: Name for the generated policy. + + Returns: + Tuple of (TrustLayerConfig, YAML string). 
+ """ + config = self.fetch_config() + yaml_str = self.to_stratix_policy(config, policy_name=policy_name) + return config, yaml_str + + @staticmethod + def _classify_guardrail(name: str) -> str: + """Classify a guardrail name into a guardrail type.""" + name_lower = name.lower() + if "toxic" in name_lower or "harm" in name_lower: + return "toxicity" + if "pii" in name_lower or "mask" in name_lower or "privacy" in name_lower: + return "pii" + if "injection" in name_lower or "jailbreak" in name_lower: + return "prompt_injection" + if "hallucin" in name_lower or "ground" in name_lower: + return "hallucination" + return "custom" diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/__init__.py b/src/layerlens/instrument/adapters/frameworks/autogen/__init__.py new file mode 100644 index 00000000..7c8961be --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/__init__.py @@ -0,0 +1,56 @@ +""" +STRATIX AutoGen Adapter + +Integrates STRATIX tracing with the Microsoft AutoGen framework. + +Usage: + from layerlens.instrument.adapters.frameworks.autogen import ( + AutoGenAdapter, + instrument_agents, + GroupChatTracer, + HumanProxyTracer, + ) + + adapter = AutoGenAdapter(stratix=stratix_instance) + adapter.connect() + adapter.connect_agents(agent1, agent2) +""" + +from __future__ import annotations + +from typing import Any + +from layerlens.instrument.adapters.frameworks.autogen.groupchat import GroupChatTracer +from layerlens.instrument.adapters.frameworks.autogen.lifecycle import AutoGenAdapter +from layerlens.instrument.adapters.frameworks.autogen.human_proxy import HumanProxyTracer + +# Registry lazy-loading convention +ADAPTER_CLASS = AutoGenAdapter + + +def instrument_agents( + *agents: Any, stratix: Any = None, capture_config: dict[str, Any] | None = None +) -> Any: + """ + Convenience function to instrument AutoGen agents with STRATIX tracing. 
+ + Args: + *agents: AutoGen ConversableAgent instances + stratix: STRATIX SDK instance + capture_config: CaptureConfig to use + + Returns: + List of instrumented agents + """ + adapter = AutoGenAdapter(stratix=stratix, capture_config=capture_config) # type: ignore[arg-type] + adapter.connect() + return adapter.connect_agents(*agents) + + +__all__ = [ + "AutoGenAdapter", + "GroupChatTracer", + "HumanProxyTracer", + "instrument_agents", + "ADAPTER_CLASS", +] diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/groupchat.py b/src/layerlens/instrument/adapters/frameworks/autogen/groupchat.py new file mode 100644 index 00000000..31519415 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/groupchat.py @@ -0,0 +1,173 @@ +""" +AutoGen GroupChat Tracing + +Traces GroupChat speaker selection and turn management for multi-agent +conversations. +""" + +from __future__ import annotations + +import time +import logging +import threading +from typing import TYPE_CHECKING, Any +from collections.abc import Callable + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.autogen.lifecycle import AutoGenAdapter + +logger = logging.getLogger(__name__) + + +class GroupChatTracer: + """ + Traces GroupChat speaker selection and turn management. + + Wraps GroupChatManager to intercept speaker selection, message routing, + and termination detection. + """ + + def __init__(self, adapter: AutoGenAdapter) -> None: + self._adapter = adapter + self._lock = threading.Lock() + self._message_seq: int = 0 + self._original_run_chat: Callable[..., Any] | None = None + + @property + def message_seq(self) -> int: + return self._message_seq + + def wrap_manager(self, manager: Any) -> Any: + """ + Wrap a GroupChatManager with tracing. 
+ + Args: + manager: An AutoGen GroupChatManager instance + + Returns: + The wrapped manager (same object, modified in-place) + """ + if hasattr(manager, "run_chat"): + self._original_run_chat = manager.run_chat + manager.run_chat = self._create_traced_run_chat(manager, manager.run_chat) + manager._stratix_tracer = self + return manager + + def on_speaker_selected( + self, + method: str | None = None, + candidates: list[str] | None = None, + chosen: str | None = None, + ) -> None: + """ + Record a speaker selection event. + + Emits agent.code (L2) dict event for the selection. + """ + try: + self._adapter.emit_dict_event( + "agent.code", + { + "framework": "autogen", + "event_subtype": "speaker_selection", + "method": method, + "candidates": candidates, + "chosen": chosen, + "message_seq": self._message_seq, + }, + ) + except Exception: + logger.warning("Error emitting speaker selection event", exc_info=True) + + def on_message_routed( + self, + from_agent: str, + to_agent: str, + message: Any = None, + ) -> None: + """ + Record a message routing event. + + Emits agent.handoff (cross-cutting). + """ + with self._lock: + self._message_seq += 1 + msg_seq = self._message_seq + try: + self._adapter.emit_dict_event( + "agent.handoff", + { + "framework": "autogen", + "from_agent": from_agent, + "to_agent": to_agent, + "reason": "groupchat_routing", + "message_seq": msg_seq, + }, + ) + except Exception: + logger.warning("Error emitting message routing event", exc_info=True) + + def on_termination( + self, + reason: str | None = None, + final_speaker: str | None = None, + ) -> None: + """ + Record conversation termination. + + Emits agent.output (L1). 
+ """ + try: + self._adapter.emit_dict_event( + "agent.output", + { + "framework": "autogen", + "event_subtype": "groupchat_termination", + "termination_reason": reason, + "final_speaker": final_speaker, + "total_messages": self._message_seq, + }, + ) + except Exception: + logger.warning("Error emitting termination event", exc_info=True) + + def _create_traced_run_chat( + self, + manager: Any, + original: Callable[..., Any], + ) -> Callable[..., Any]: + """Create a traced version of run_chat.""" + tracer = self + + def traced_run_chat(*args: Any, **kwargs: Any) -> Any: + start_ns = time.time_ns() + + try: + tracer._adapter.emit_dict_event( + "agent.input", + { + "framework": "autogen", + "event_subtype": "groupchat_start", + "timestamp_ns": start_ns, + }, + ) + except Exception: + logger.warning("Error emitting groupchat start", exc_info=True) + + result = original(*args, **kwargs) + + try: + # Source had an orphan duration calc here with no side + # effect. Removed during port — termination signal is the + # only thing the original code dispatched. + tracer.on_termination( + reason="run_chat_complete", + final_speaker=None, + ) + except Exception: + logger.warning("Error emitting groupchat end", exc_info=True) + + return result + + traced_run_chat._layerlens_original = original # type: ignore[attr-defined] + return traced_run_chat diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/human_proxy.py b/src/layerlens/instrument/adapters/frameworks/autogen/human_proxy.py new file mode 100644 index 00000000..cb976a72 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/human_proxy.py @@ -0,0 +1,141 @@ +""" +AutoGen Human-in-the-Loop Tracing + +Traces human interactions through UserProxyAgent, capturing requests, +responses, and approval patterns. 
+""" + +from __future__ import annotations + +import time +import logging +import threading +from typing import TYPE_CHECKING, Any +from collections.abc import Callable + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.autogen.lifecycle import AutoGenAdapter + +logger = logging.getLogger(__name__) + + +class HumanProxyTracer: + """ + Traces human interactions through UserProxyAgent. + + Wraps get_human_input() to capture human requests and responses, + measure response latency, and detect approval patterns. + """ + + def __init__(self, adapter: AutoGenAdapter) -> None: + self._adapter = adapter + self._lock = threading.Lock() + self._original_get_human_input: Callable[..., Any] | None = None + self._interaction_count: int = 0 + + @property + def interaction_count(self) -> int: + return self._interaction_count + + def wrap_agent(self, agent: Any) -> Any: + """ + Wrap a UserProxyAgent with human interaction tracing. + + Args: + agent: An AutoGen UserProxyAgent instance + + Returns: + The wrapped agent (same object, modified in-place) + """ + if hasattr(agent, "get_human_input"): + self._original_get_human_input = agent.get_human_input + agent.get_human_input = self._create_traced_get_human_input( + agent, agent.get_human_input + ) + agent._stratix_human_tracer = self + return agent + + def _create_traced_get_human_input( + self, + agent: Any, + original: Callable[..., Any], + ) -> Callable[..., Any]: + """Create a traced version of get_human_input.""" + tracer = self + + def traced_get_human_input(prompt: str = "", **kwargs: Any) -> str: + start_ns = time.time_ns() + with tracer._lock: + tracer._interaction_count += 1 + interaction_seq = tracer._interaction_count + + # Emit request event + try: + agent_name = getattr(agent, "name", str(agent)) + tracer._adapter.emit_dict_event( + "agent.input", + { + "framework": "autogen", + "role": "HUMAN", + "input_type": "human_input_request", + "agent": agent_name, + "prompt": prompt[:500] if prompt else "", + 
"interaction_seq": interaction_seq, + }, + ) + except Exception: + logger.warning("Error emitting human input request", exc_info=True) + + # Call original + response = original(prompt, **kwargs) + + # Emit response event + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + input_type = tracer._classify_input(response) + tracer._adapter.emit_dict_event( + "agent.input", + { + "framework": "autogen", + "role": "HUMAN", + "input_type": input_type, + "agent": agent_name, + "response_preview": response[:500] if response else "", + "response_latency_ms": elapsed_ms, + "interaction_seq": interaction_seq, + }, + ) + except Exception: + logger.warning("Error emitting human input response", exc_info=True) + + return response # type: ignore[no-any-return] + + traced_get_human_input._layerlens_original = original # type: ignore[attr-defined] + return traced_get_human_input + + def _classify_input(self, response: str) -> str: + """ + Classify the type of human input. + + Returns: + Input type string: "approval", "rejection", "auto_reply", + "custom_input", or "empty" + """ + if not response: + return "empty" + + lower = response.strip().lower() + + # Auto-reply detection + if lower in ("", "exit"): + return "auto_reply" + + # Approval patterns + if lower in ("y", "yes", "approve", "ok", "okay", "sure", "proceed", "continue"): + return "approval" + + # Rejection patterns + if lower in ("n", "no", "reject", "deny", "stop", "cancel", "abort"): + return "rejection" + + return "custom_input" diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/autogen/lifecycle.py new file mode 100644 index 00000000..50794d29 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/lifecycle.py @@ -0,0 +1,603 @@ +""" +STRATIX AutoGen Lifecycle Hooks + +Provides the main AutoGenAdapter class with monkey-patch-based instrumentation +for AutoGen ConversableAgent instances. 
+""" + +from __future__ import annotations + +import time +import uuid +import logging +import threading +from typing import Any + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.autogen.metadata import AutoGenAgentMetadataExtractor + +logger = logging.getLogger(__name__) + + +class AutoGenAdapter(BaseAdapter): + """ + Main adapter for integrating STRATIX with Microsoft AutoGen. + + Uses monkey-patching to intercept ConversableAgent methods (send, receive, + generate_reply, execute_code_blocks) and emit STRATIX telemetry events. + + Supports both new-style (stratix, capture_config) and legacy (stratix_instance) + constructor parameters. + + Usage: + adapter = AutoGenAdapter(stratix=stratix_instance) + adapter.connect() + adapter.connect_agents(agent1, agent2) + agent1.initiate_chat(agent2, message="Hello") + """ + + FRAMEWORK = "autogen" + VERSION = "0.1.0" + # The adapter source files import nothing from ``pydantic`` directly + # (verified by grep across ``frameworks/autogen/``). pyautogen 0.2.x + # supports both Pydantic majors; the adapter only monkey-patches + # ConversableAgent methods and emits dict events, never touching the + # framework's Pydantic models. 
+ requires_pydantic = PydanticCompat.V1_OR_V2 + + def __init__( + self, + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + # Legacy param + stratix_instance: Any | None = None, + memory_service: Any | None = None, + ) -> None: + resolved_stratix = stratix or stratix_instance + super().__init__(stratix=resolved_stratix, capture_config=capture_config) + + self._metadata_extractor = AutoGenAgentMetadataExtractor() + self._adapter_lock = threading.Lock() + self._seen_agents: set[str] = set() + self._wrapped_agents: list[Any] = [] + self._originals: dict[int, dict[str, Any]] = {} # agent id -> original methods + self._message_seq: int = 0 + self._conversation_start_ns: int = 0 + self._framework_version: str | None = None + self._memory_service = memory_service + + # --- BaseAdapter lifecycle --- + + def connect(self) -> None: + """Verify AutoGen is importable and mark as connected.""" + try: + import autogen # type: ignore[import-not-found,unused-ignore] # noqa: F401 + + version = getattr(autogen, "__version__", "unknown") + logger.debug("AutoGen %s detected", version) + except ImportError: + logger.debug("AutoGen not installed; adapter usable in mock/test mode") + self._framework_version = self._detect_framework_version() + self._connected = True + self._status = AdapterStatus.HEALTHY + + def disconnect(self) -> None: + """Unwrap agents and disconnect.""" + for agent in self._wrapped_agents: + self._unwrap_agent(agent) + self._wrapped_agents.clear() + self._originals.clear() + self._connected = False + self._status = AdapterStatus.DISCONNECTED + + def health_check(self) -> AdapterHealth: + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + framework_version=self._framework_version, + adapter_version=self.VERSION, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="AutoGenAdapter", + version=self.VERSION, + 
framework=self.FRAMEWORK, + framework_version=self._framework_version, + capabilities=[ + AdapterCapability.TRACE_TOOLS, + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_STATE, + AdapterCapability.TRACE_HANDOFFS, + ], + description="LayerLens adapter for Microsoft AutoGen framework", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + return ReplayableTrace( + adapter_name="AutoGenAdapter", + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + state_snapshots=[], + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + # --- Agent wrapping --- + + def connect_agents(self, *agents: Any) -> list[Any]: + """ + Monkey-patch AutoGen agents with STRATIX tracing. + + Wraps send, receive, generate_reply, and execute_code_blocks methods. + Stores originals for unwrap on disconnect. + + Emits environment.config (L4a) on first encounter per agent. + + Args: + *agents: AutoGen ConversableAgent instances + + Returns: + List of wrapped agents (same objects, modified in-place) + """ + from layerlens.instrument.adapters.frameworks.autogen.wrappers import ( + create_traced_send, + create_traced_receive, + create_traced_execute_code, + create_traced_generate_reply, + ) + + result = [] + for agent in agents: + agent_id = id(agent) + if agent_id in self._originals: + result.append(agent) + continue + + originals: dict[str, Any] = {} + + # Wrap send + if hasattr(agent, "send"): + originals["send"] = agent.send + agent.send = create_traced_send(self, agent, agent.send) + + # Wrap receive + if hasattr(agent, "receive"): + originals["receive"] = agent.receive + agent.receive = create_traced_receive(self, agent, agent.receive) + + # Wrap generate_reply + if hasattr(agent, "generate_reply"): + originals["generate_reply"] = agent.generate_reply + agent.generate_reply = create_traced_generate_reply( + self, agent, agent.generate_reply + ) + + # Wrap execute_code_blocks + if hasattr(agent, 
"execute_code_blocks"): + originals["execute_code_blocks"] = agent.execute_code_blocks + agent.execute_code_blocks = create_traced_execute_code( + self, agent, agent.execute_code_blocks + ) + + self._originals[agent_id] = originals + self._wrapped_agents.append(agent) + + # Emit agent config on first encounter + self._emit_agent_config(agent) + + result.append(agent) + + return result + + def _unwrap_agent(self, agent: Any) -> None: + """Restore original methods on an agent.""" + agent_id = id(agent) + originals = self._originals.get(agent_id) + if not originals: + return + for method_name, original in originals.items(): + try: + setattr(agent, method_name, original) + except Exception: + logger.debug("Could not unwrap %s on agent", method_name, exc_info=True) + + # --- Lifecycle hooks (called by wrappers) --- + + def on_send( + self, + sender: Any, + message: Any, + recipient: Any, + ) -> None: + """ + Handle agent send. + + Emits agent.handoff (cross-cutting). + """ + with self._adapter_lock: + self._message_seq += 1 + msg_seq = self._message_seq + sender_name = getattr(sender, "name", str(sender)) + recipient_name = getattr(recipient, "name", str(recipient)) + + self.emit_dict_event( + "agent.handoff", + { + "framework": "autogen", + "from_agent": sender_name, + "to_agent": recipient_name, + "message_preview": self._truncate(self._message_content(message)), + "message_seq": msg_seq, + }, + ) + + def on_receive( + self, + receiver: Any, + message: Any, + sender: Any, + ) -> None: + """ + Handle agent receive. + + Emits agent.state.change (cross-cutting). 
+ """ + receiver_name = getattr(receiver, "name", str(receiver)) + sender_name = getattr(sender, "name", str(sender)) if sender else None + + self.emit_dict_event( + "agent.state.change", + { + "framework": "autogen", + "agent": receiver_name, + "event_subtype": "message_received", + "from_agent": sender_name, + "message_preview": self._truncate(self._message_content(message)), + }, + ) + + def on_generate_reply( + self, + agent: Any, + messages: Any = None, + reply: Any = None, + latency_ms: float | None = None, + ) -> None: + """ + Handle reply generation. + + Emits model.invoke (L3). + """ + agent_name = getattr(agent, "name", str(agent)) + model = self._extract_model_name(agent) + + payload: dict[str, Any] = { + "framework": "autogen", + "agent": agent_name, + "model": model, + "reply_preview": self._truncate(self._message_content(reply)), + } + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + # Extract token counts if available + token_usage = self._extract_token_usage_from_reply(reply) + if token_usage: + payload.update(token_usage) + + # Include messages for Prompt Lab extraction (gated by capture_content) + if self._capture_config.capture_content and messages: + normalized: list[dict[str, str]] = [] + # Prepend system message from agent config + sys_msg = self._extract_system_message(agent) + if sys_msg: + normalized.append({"role": "system", "content": self._truncate(sys_msg, 10_000)}) + if isinstance(messages, list): + for msg in messages: + if isinstance(msg, dict) and "role" in msg and "content" in msg: + normalized.append( + { + "role": str(msg["role"]), + "content": str(msg["content"])[:10_000], + } + ) + elif isinstance(msg, str): + normalized.append({"role": "user", "content": msg[:10_000]}) + if normalized: + payload["messages"] = normalized + + self.emit_dict_event("model.invoke", payload) + + def on_execute_code( + self, + agent: Any, + code_blocks: Any = None, + result: Any = None, + latency_ms: float | None = None, + ) -> 
None: + """ + Handle code execution. + + Emits tool.call (L5a) and tool.environment (L5c). + """ + agent_name = getattr(agent, "name", str(agent)) + + # tool.call for the code execution + self.emit_dict_event( + "tool.call", + { + "framework": "autogen", + "tool_name": "code_execution", + "agent": agent_name, + "code_blocks_count": len(code_blocks) if code_blocks else 0, + "result_preview": self._truncate(str(result)) if result else None, + "latency_ms": latency_ms, + }, + ) + + # tool.environment for execution environment details + self.emit_dict_event( + "tool.environment", + { + "framework": "autogen", + "agent": agent_name, + "execution_type": "code_block", + "code_blocks_count": len(code_blocks) if code_blocks else 0, + }, + ) + + def on_conversation_start( + self, + initiator: Any, + message: Any, + ) -> None: + """ + Handle conversation start. + + Emits agent.input (L1). + """ + with self._adapter_lock: + self._conversation_start_ns = time.time_ns() + initiator_name = getattr(initiator, "name", str(initiator)) + + self.emit_dict_event( + "agent.input", + { + "framework": "autogen", + "initiator": initiator_name, + "message": self._safe_serialize(message), + "timestamp_ns": self._conversation_start_ns, + }, + ) + + def on_conversation_end( + self, + final_message: Any = None, + termination_reason: str | None = None, + ) -> None: + """ + Handle conversation end. + + Emits agent.output (L1). + """ + end_ns = time.time_ns() + duration_ns = end_ns - self._conversation_start_ns if self._conversation_start_ns else 0 + + self.emit_dict_event( + "agent.output", + { + "framework": "autogen", + "final_message": self._safe_serialize(final_message), + "termination_reason": termination_reason, + "duration_ns": duration_ns, + }, + ) + + # --- Memory integration --- + + def on_message( + self, + agent_id: str, + message: Any, + ) -> None: + """Store an agent message as episodic memory. + + Only active when ``memory_service`` is provided. 
Failures are + logged and swallowed to avoid disrupting normal operation. + + Args: + agent_id: Identifier of the agent that sent/received the message. + message: The message content to store. + """ + if self._memory_service is None: + return + + try: + from layerlens.instrument._vendored.memory_models import MemoryEntry + + content = self._message_content(message) + timestamp = int(time.time()) + entry = MemoryEntry( + org_id=getattr(self._stratix, "org_id", ""), + agent_id=agent_id, + memory_type="episodic", + key=f"message_{timestamp}", + content=content[:4000], + importance=0.4, + metadata={"source": "autogen_adapter"}, + ) + self._memory_service.store(entry) + except Exception: + logger.debug( + "Failed to store episodic memory for agent %s", + agent_id, + exc_info=True, + ) + + def on_conversation_end_memory( + self, + agent_id: str, + summary: str, + ) -> None: + """Consolidate a conversation summary into semantic memory. + + Only active when ``memory_service`` is provided. Failures are + logged and swallowed. + + Args: + agent_id: Agent whose conversation to consolidate. + summary: High-level summary of the conversation. 
+ """ + if self._memory_service is None: + return + + try: + from layerlens.instrument._vendored.memory_models import MemoryEntry + + timestamp = int(time.time()) + entry = MemoryEntry( + org_id=getattr(self._stratix, "org_id", ""), + agent_id=agent_id, + memory_type="semantic", + key=f"conversation_summary_{timestamp}", + content=summary[:4000], + importance=0.7, + metadata={"source": "autogen_adapter", "type": "conversation_consolidation"}, + ) + self._memory_service.store(entry) + except Exception: + logger.debug( + "Failed to store conversation summary for agent %s", + agent_id, + exc_info=True, + ) + + # --- Agent config emission --- + + def _emit_agent_config(self, agent: Any) -> None: + """Emit environment.config for an agent on first encounter.""" + name = getattr(agent, "name", None) or str(agent) + with self._adapter_lock: + if name in self._seen_agents: + return + self._seen_agents.add(name) + + metadata = self._metadata_extractor.extract(agent) + + self.emit_dict_event( + "environment.config", + { + "framework": "autogen", + **metadata, + }, + ) + + # --- Internal helpers --- + + def _safe_serialize(self, value: Any) -> Any: + """Safely serialize a value for events.""" + try: + if value is None: + return None + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + if isinstance(value, dict): + return dict(value) + if isinstance(value, (str, int, float, bool)): + return value + return str(value) + except Exception: + return str(value) + + def _message_content(self, message: Any) -> str: + """Extract string content from a message.""" + if message is None: + return "" + if isinstance(message, str): + return message + if isinstance(message, dict): + return str(message.get("content", message)) + return str(message) + + def _truncate(self, text: str, max_len: int = 500) -> str: + """Truncate text to max_len.""" + if len(text) <= max_len: + return text + return text[:max_len] + "..." 
+ + def _extract_system_message(self, agent: Any) -> str | None: + """Extract system message from agent config.""" + try: + # AutoGen 0.2.x: agent.system_message + sys_msg = getattr(agent, "system_message", None) + if sys_msg: + return str(sys_msg) + # AutoGen 0.4+/agentchat: agent._system_messages + sys_msgs = getattr(agent, "_system_messages", None) + if sys_msgs and isinstance(sys_msgs, list) and sys_msgs: + first = sys_msgs[0] + content = getattr(first, "content", None) or str(first) + return str(content) + except Exception: + pass + return None + + def _extract_model_name(self, agent: Any) -> str | None: + """Extract model name from agent's llm_config.""" + try: + llm_config = getattr(agent, "llm_config", None) + if not llm_config or not isinstance(llm_config, dict): + return None + if "model" in llm_config: + return llm_config["model"] # type: ignore[no-any-return] + config_list = llm_config.get("config_list", []) + if config_list and isinstance(config_list[0], dict): + return config_list[0].get("model") + except Exception: + pass + return None + + def _extract_token_usage_from_reply(self, reply: Any) -> dict[str, Any] | None: + """Extract token usage from a reply if available.""" + if reply is None: + return None + try: + usage = getattr(reply, "usage", None) + if usage: + if isinstance(usage, dict): + return { + "tokens_prompt": usage.get("prompt_tokens"), + "tokens_completion": usage.get("completion_tokens"), + } + return { + "tokens_prompt": getattr(usage, "prompt_tokens", None), + "tokens_completion": getattr(usage, "completion_tokens", None), + } + except Exception: + pass + return None + + @staticmethod + def _detect_framework_version() -> str | None: + try: + import autogen # type: ignore[import-not-found,unused-ignore] + + return getattr(autogen, "__version__", None) + except ImportError: + return None diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/metadata.py b/src/layerlens/instrument/adapters/frameworks/autogen/metadata.py new 
file mode 100644 index 00000000..2d9ab934 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/metadata.py @@ -0,0 +1,94 @@ +""" +AutoGen Agent Metadata Extraction + +Extracts agent metadata for L4a (environment.config) emission. +""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class AutoGenAgentMetadataExtractor: + """Extracts AutoGen agent metadata for environment.config emission.""" + + def extract(self, agent: Any) -> dict[str, Any]: + """ + Extract metadata from an AutoGen ConversableAgent. + + Args: + agent: An AutoGen ConversableAgent instance + + Returns: + Dict of agent metadata + """ + metadata: dict[str, Any] = {} + + # Agent name + try: + metadata["name"] = getattr(agent, "name", str(agent)) + except Exception: + metadata["name"] = "" + + # System message + try: + system_message = getattr(agent, "system_message", None) + if system_message is not None: + metadata["system_message"] = ( + system_message[:500] if len(system_message) > 500 else system_message + ) + except Exception: + pass + + # Human input mode + try: + him = getattr(agent, "human_input_mode", None) + if him is not None: + metadata["human_input_mode"] = him + except Exception: + pass + + # LLM config + try: + llm_config = getattr(agent, "llm_config", None) + if llm_config and isinstance(llm_config, dict): + safe_config: dict[str, Any] = {} + if "model" in llm_config: + safe_config["model"] = llm_config["model"] + if "config_list" in llm_config: + models = [] + for cfg in llm_config["config_list"]: + if isinstance(cfg, dict) and "model" in cfg: + models.append(cfg["model"]) + if models: + safe_config["models"] = models + if "temperature" in llm_config: + safe_config["temperature"] = llm_config["temperature"] + metadata["llm_config"] = safe_config + except Exception: + pass + + # Max consecutive auto reply + try: + max_reply = getattr(agent, "max_consecutive_auto_reply", None) + if max_reply 
is not None: + metadata["max_consecutive_auto_reply"] = max_reply + except Exception: + pass + + # Code execution config + try: + code_config = getattr(agent, "code_execution_config", None) + if code_config and isinstance(code_config, dict): + safe_code_config: dict[str, Any] = {} + for key in ("work_dir", "use_docker", "timeout", "last_n_messages"): + if key in code_config: + safe_code_config[key] = code_config[key] + metadata["code_execution_config"] = safe_code_config + except Exception: + pass + + return metadata diff --git a/src/layerlens/instrument/adapters/frameworks/autogen/wrappers.py b/src/layerlens/instrument/adapters/frameworks/autogen/wrappers.py new file mode 100644 index 00000000..cfa23f9f --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/autogen/wrappers.py @@ -0,0 +1,155 @@ +""" +AutoGen Method Wrappers + +Creates traced versions of ConversableAgent methods that intercept calls +and route events to the AutoGenAdapter lifecycle hooks. + +All wrappers preserve the original method's behavior and handle adapter +exceptions silently to prevent tracing from breaking the application. +""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any +from collections.abc import Callable + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.autogen.lifecycle import AutoGenAdapter + +logger = logging.getLogger(__name__) + + +def create_traced_send( + adapter: AutoGenAdapter, + agent: Any, + original_send: Callable[..., Any], +) -> Callable[..., Any]: + """ + Create a traced version of agent.send(). + + Captures the message being sent and the recipient, then delegates + to the original send method. 
+ """ + + def traced_send(message: Any, recipient: Any, *args: Any, **kwargs: Any) -> Any: + try: + adapter.on_send(sender=agent, message=message, recipient=recipient) + except Exception: + logger.warning("Error in traced send pre-hook", exc_info=True) + + return original_send(message, recipient, *args, **kwargs) + + traced_send._layerlens_original = original_send # type: ignore[attr-defined] + return traced_send + + +def create_traced_receive( + adapter: AutoGenAdapter, + agent: Any, + original_receive: Callable[..., Any], +) -> Callable[..., Any]: + """ + Create a traced version of agent.receive(). + + Captures the received message and sender, then delegates to the + original receive method, forwarding positional extras (request_reply, silent). + """ + + def traced_receive(message: Any, sender: Any, *args: Any, **kwargs: Any) -> Any: + try: + adapter.on_receive(receiver=agent, message=message, sender=sender) + except Exception: + logger.warning("Error in traced receive pre-hook", exc_info=True) + + return original_receive(message, sender, *args, **kwargs) + + traced_receive._layerlens_original = original_receive # type: ignore[attr-defined] + return traced_receive + + +def create_traced_generate_reply( + adapter: AutoGenAdapter, + agent: Any, + original_generate_reply: Callable[..., Any], +) -> Callable[..., Any]: + """ + Create a traced version of agent.generate_reply(). + + Captures timing and the generated reply, then delegates to the + original method.
+ """ + + def traced_generate_reply(messages: Any = None, sender: Any = None, **kwargs: Any) -> Any: + start_ns = time.time_ns() + error: Exception | None = None + + try: + reply = original_generate_reply(messages=messages, sender=sender, **kwargs) + except Exception as exc: + error = exc + reply = None + raise + finally: + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + if error is not None: + # Emit model.invoke with error information for failed calls + adapter.emit_dict_event( + "model.invoke", + { + "framework": "autogen", + "agent": getattr(agent, "name", str(agent)), + "model": adapter._extract_model_name(agent), + "latency_ms": elapsed_ms, + "error": str(error), + }, + ) + else: + adapter.on_generate_reply( + agent=agent, + messages=messages, + reply=reply, + latency_ms=elapsed_ms, + ) + except Exception: + logger.warning("Error in traced generate_reply post-hook", exc_info=True) + + return reply + + traced_generate_reply._layerlens_original = original_generate_reply # type: ignore[attr-defined] + return traced_generate_reply + + +def create_traced_execute_code( + adapter: AutoGenAdapter, + agent: Any, + original_execute_code: Callable[..., Any], +) -> Callable[..., Any]: + """ + Create a traced version of agent.execute_code_blocks(). + + Captures code blocks, execution result, and timing. 
+ """ + + def traced_execute_code(code_blocks: Any, **kwargs: Any) -> Any: + start_ns = time.time_ns() + + result = original_execute_code(code_blocks, **kwargs) + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + adapter.on_execute_code( + agent=agent, + code_blocks=code_blocks, + result=result, + latency_ms=elapsed_ms, + ) + except Exception: + logger.warning("Error in traced execute_code post-hook", exc_info=True) + + return result + + traced_execute_code._layerlens_original = original_execute_code # type: ignore[attr-defined] + return traced_execute_code diff --git a/src/layerlens/instrument/adapters/frameworks/crewai/__init__.py b/src/layerlens/instrument/adapters/frameworks/crewai/__init__.py new file mode 100644 index 00000000..4a1c23d4 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai/__init__.py @@ -0,0 +1,66 @@ +""" +STRATIX CrewAI Adapter + +Integrates STRATIX tracing with the CrewAI agent framework. + +Usage: + from layerlens.instrument.adapters.frameworks.crewai import ( + CrewAIAdapter, + LayerLensCrewCallback, + instrument_crew, + ) + + adapter = CrewAIAdapter(stratix=stratix_instance) + adapter.connect() + instrumented_crew = adapter.instrument_crew(my_crew) + result = instrumented_crew.kickoff() +""" + +from __future__ import annotations + +from typing import Any + +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat, requires_pydantic + +# Round-2 deliberation item 20: CrewAI >=0.30 pins ``pydantic = "^2"``; +# fail fast under v1. 
+requires_pydantic(PydanticCompat.V2_ONLY) + +from layerlens.instrument.adapters.frameworks.crewai.metadata import AgentMetadataExtractor +from layerlens.instrument.adapters.frameworks.crewai.callbacks import LayerLensCrewCallback +from layerlens.instrument.adapters.frameworks.crewai.lifecycle import CrewAIAdapter +from layerlens.instrument.adapters.frameworks.crewai.delegation import CrewDelegationTracker + +# Registry lazy-loading convention +ADAPTER_CLASS = CrewAIAdapter + + +def instrument_crew(crew: Any, stratix: Any = None, capture_config: dict[str, Any] | None = None) -> Any: + """ + Convenience function to instrument a CrewAI crew with STRATIX tracing. + + Args: + crew: A CrewAI Crew instance + stratix: STRATIX SDK instance + capture_config: CaptureConfig to use + + Returns: + The instrumented crew + """ + adapter = CrewAIAdapter(stratix=stratix, capture_config=capture_config) # type: ignore[arg-type] + adapter.connect() + return adapter.instrument_crew(crew) + + +__all__ = [ + "CrewAIAdapter", + "LayerLensCrewCallback", + "CrewDelegationTracker", + "AgentMetadataExtractor", + "instrument_crew", + "ADAPTER_CLASS", +] + + +# Backward-compat aliases for users coming from ateam. +STRATIXCrewCallback = LayerLensCrewCallback # noqa: N816 - backward-compat alias for ateam users diff --git a/src/layerlens/instrument/adapters/frameworks/crewai/callbacks.py b/src/layerlens/instrument/adapters/frameworks/crewai/callbacks.py new file mode 100644 index 00000000..5a8ce077 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai/callbacks.py @@ -0,0 +1,242 @@ +""" +CrewAI Callback Handler + +Routes CrewAI callback events to the CrewAIAdapter lifecycle hooks. +All methods wrap adapter calls in try/except to prevent tracing from +crashing the crew execution.
+""" + +from __future__ import annotations + +import time +import logging +import threading +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.crewai.lifecycle import CrewAIAdapter + +logger = logging.getLogger(__name__) + + +class LayerLensCrewCallback: + """ + CrewAI callback handler that routes events to CrewAIAdapter. + + Implements the CrewAI callback protocol and translates framework + callbacks into STRATIX lifecycle hook calls. + """ + + def __init__(self, adapter: CrewAIAdapter) -> None: + self._adapter = adapter + self._lock = threading.Lock() + self._seen_agents: set[str] = set() + self._task_counter: int = 0 + self._current_task_start_ns: int = 0 + + # --- CrewAI callback methods --- + + def on_crew_start(self, crew: Any = None, inputs: Any = None) -> None: + """Called when crew execution begins.""" + try: + self._adapter.on_crew_start(crew_input=inputs) + except Exception: + logger.warning("Error in on_crew_start callback", exc_info=True) + + def on_crew_end(self, crew: Any = None, output: Any = None) -> None: + """Called when crew execution completes.""" + try: + self._adapter.on_crew_end(crew_output=output) + except Exception: + logger.warning("Error in on_crew_end callback", exc_info=True) + + def on_task_start(self, task: Any = None) -> None: + """Called when a task begins execution.""" + try: + with self._lock: + self._task_counter += 1 + self._current_task_start_ns = time.time_ns() + task_counter = self._task_counter + + description = getattr(task, "description", None) or "" + expected_output = getattr(task, "expected_output", None) + agent = getattr(task, "agent", None) + agent_role = getattr(agent, "role", None) if agent else None + + # Emit agent config on first encounter + if agent and agent_role: + with self._lock: + seen = agent_role in self._seen_agents + if not seen: + self._seen_agents.add(agent_role) + if not seen: + self._adapter._emit_agent_config(agent) + + 
self._adapter.on_task_start( + task_description=description, + agent_role=agent_role, + expected_output=expected_output, + task_order=task_counter, + ) + except Exception: + logger.warning("Error in on_task_start callback", exc_info=True) + + def on_task_end(self, task: Any = None, output: Any = None) -> None: + """Called when a task completes.""" + try: + agent = getattr(task, "agent", None) if task else None + agent_role = getattr(agent, "role", None) if agent else None + + self._adapter.on_task_end( + task_output=output, + agent_role=agent_role, + task_order=self._task_counter, + ) + except Exception: + logger.warning("Error in on_task_end callback", exc_info=True) + + def on_agent_action(self, agent: Any = None, action: Any = None) -> None: + """Called when an agent takes an action.""" + try: + role = getattr(agent, "role", None) if agent else None + + # Emit agent config on first encounter + if agent and role: + with self._lock: + seen = role in self._seen_agents + if not seen: + self._seen_agents.add(role) + if not seen: + self._adapter._emit_agent_config(agent) + except Exception: + logger.warning("Error in on_agent_action callback", exc_info=True) + + def on_agent_end(self, agent: Any = None, output: Any = None) -> None: + """Called when an agent finishes processing.""" + try: + role = getattr(agent, "role", None) if agent else None + self._adapter.emit_dict_event( + "agent.state.change", + { + "framework": "crewai", + "agent_role": role, + "event_subtype": "agent_complete", + "output": self._adapter._safe_serialize(output), + }, + ) + except Exception: + logger.warning("Error in on_agent_end callback", exc_info=True) + + def on_tool_use( + self, + agent: Any = None, + tool_name: str = "", + tool_input: Any = None, + tool_output: Any = None, + ) -> None: + """Called when an agent uses a tool.""" + try: + self._adapter.on_tool_use( + tool_name=tool_name, + tool_input=tool_input, + tool_output=tool_output, + ) + except Exception: + logger.warning("Error in 
on_tool_use callback", exc_info=True) + + def on_llm_call(self, agent: Any = None, response: Any = None) -> None: + """Called when an LLM call completes.""" + try: + provider = None + model = None + tokens_prompt = None + tokens_completion = None + + if response is not None: + # Try to extract model info from response + model = getattr(response, "model", None) or getattr(response, "model_name", None) + provider = self._detect_provider(response) + + # Token usage + usage = getattr(response, "usage", None) + if usage: + if isinstance(usage, dict): + tokens_prompt = usage.get("prompt_tokens") + tokens_completion = usage.get("completion_tokens") + else: + tokens_prompt = getattr(usage, "prompt_tokens", None) + tokens_completion = getattr(usage, "completion_tokens", None) + + self._adapter.on_llm_call( + provider=provider, + model=model, + tokens_prompt=tokens_prompt, + tokens_completion=tokens_completion, + ) + except Exception: + logger.warning("Error in on_llm_call callback", exc_info=True) + + # --- Step/task callbacks (attached to crew) --- + + def on_step(self, step_output: Any = None) -> None: + """ + CrewAI step_callback handler. + + Called after each agent step. Routes to appropriate handler. 
+ """ + try: + # Extract tool usage from step output if present + tool_name = getattr(step_output, "tool", None) + if tool_name: + tool_input = getattr(step_output, "tool_input", None) + tool_output = getattr(step_output, "result", None) + self._adapter.on_tool_use( + tool_name=tool_name, + tool_input=tool_input, + tool_output=tool_output, + ) + + # Check for delegation + delegated_to = getattr(step_output, "delegated_to", None) + if delegated_to: + delegated_from = getattr(step_output, "agent", None) + from_role = ( + getattr(delegated_from, "role", "unknown") if delegated_from else "unknown" + ) + to_role = ( + getattr(delegated_to, "role", str(delegated_to)) if delegated_to else "unknown" + ) + context = getattr(step_output, "result", None) + self._adapter.on_delegation(from_role, to_role, context) + except Exception: + logger.warning("Error in on_step callback", exc_info=True) + + def on_task_complete(self, task_output: Any = None) -> None: + """ + CrewAI task_callback handler. + + Called after each task completes. 
+ """ + try: + self._adapter.on_task_end(task_output=task_output) + except Exception: + logger.warning("Error in on_task_complete callback", exc_info=True) + + # --- Internal helpers --- + + def _detect_provider(self, response: Any) -> str | None: + """Detect LLM provider from the response object's module name.""" + try: + module_name = type(response).__module__ or "" + lower = module_name.lower() + if "openai" in lower: + return "openai" + if "anthropic" in lower: + return "anthropic" + if "google" in lower: + return "google" + if "cohere" in lower: + return "cohere" + except Exception: + pass + return None diff --git a/src/layerlens/instrument/adapters/frameworks/crewai/delegation.py b/src/layerlens/instrument/adapters/frameworks/crewai/delegation.py new file mode 100644 index 00000000..76dffd2a --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai/delegation.py @@ -0,0 +1,83 @@ +""" +CrewAI Delegation Detection + +Tracks delegation in hierarchical CrewAI processes and emits agent.handoff events. +""" + +from __future__ import annotations + +import hashlib +import logging +import threading +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.crewai.lifecycle import CrewAIAdapter + +logger = logging.getLogger(__name__) + + +class CrewDelegationTracker: + """Tracks delegation in hierarchical CrewAI processes.""" + + def __init__(self, adapter: CrewAIAdapter) -> None: + self._adapter = adapter + self._lock = threading.Lock() + self._delegation_count = 0 + + @property + def delegation_count(self) -> int: + return self._delegation_count + + def track_delegation( + self, + from_agent: str, + to_agent: str, + context: Any = None, + ) -> None: + """ + Record a delegation from one agent to another. + + Emits an agent.handoff (cross-cutting, always enabled) event.
+ + Args: + from_agent: Role/name of the delegating agent + to_agent: Role/name of the delegate agent + context: Optional context passed with the delegation + """ + with self._lock: + self._delegation_count += 1 + delegation_seq = self._delegation_count + + context_str = self._summarize_context(context) + context_hash = self._hash_context(context_str) + + try: + self._adapter.emit_dict_event( + "agent.handoff", + { + "from_agent": from_agent, + "to_agent": to_agent, + "reason": "delegation", + "context_hash": context_hash, + "context_preview": context_str[:500] if context_str else None, + "delegation_seq": delegation_seq, + }, + ) + except Exception: + logger.warning("Failed to emit delegation handoff", exc_info=True) + + def _summarize_context(self, context: Any) -> str: + """Safely summarize delegation context.""" + if context is None: + return "" + try: + if isinstance(context, str): + return context + return str(context) + except Exception: + return "" + + def _hash_context(self, context_str: str) -> str: + """SHA-256 hash of context string.""" + return hashlib.sha256(context_str.encode("utf-8", errors="replace")).hexdigest() diff --git a/src/layerlens/instrument/adapters/frameworks/crewai/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/crewai/lifecycle.py new file mode 100644 index 00000000..36941bee --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai/lifecycle.py @@ -0,0 +1,504 @@ +""" +STRATIX CrewAI Lifecycle Hooks + +Provides the main CrewAIAdapter class and crew instrumentation. 
+""" + +from __future__ import annotations + +import time +import uuid +import logging +import threading +from typing import Any + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.crewai.metadata import AgentMetadataExtractor +from layerlens.instrument.adapters.frameworks.crewai.delegation import CrewDelegationTracker + +logger = logging.getLogger(__name__) + + +class CrewAIAdapter(BaseAdapter): + """ + Main adapter for integrating STRATIX with CrewAI. + + Instruments CrewAI crews, agents, and tasks to emit STRATIX telemetry events. + Uses the CrewAI callback protocol (v0.41+) via LayerLensCrewCallback. + + Supports both new-style (stratix, capture_config) and legacy (stratix_instance) + constructor parameters. + + Usage: + adapter = CrewAIAdapter(stratix=stratix_instance) + adapter.connect() + instrumented_crew = adapter.instrument_crew(my_crew) + result = instrumented_crew.kickoff() + """ + + FRAMEWORK = "crewai" + VERSION = "0.1.0" + # CrewAI >=0.30 (pyproject pin: crewai>=0.30,<0.90) is Pydantic v2 + # only — see crewai's pyproject which pins ``pydantic = "^2.4.2"``. + # Importing crewai under v1 fails inside crewai's own model layer. 
+ requires_pydantic = PydanticCompat.V2_ONLY + + def __init__( + self, + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + # Legacy param + stratix_instance: Any | None = None, + memory_service: Any | None = None, + ) -> None: + resolved_stratix = stratix or stratix_instance + super().__init__(stratix=resolved_stratix, capture_config=capture_config) + + self._metadata_extractor = AgentMetadataExtractor() + self._delegation_tracker = CrewDelegationTracker(self) + self._adapter_lock = threading.Lock() + self._seen_agents: set[str] = set() + self._crew_start_ns: int = 0 + self._framework_version: str | None = None + self._memory_service = memory_service + + # --- BaseAdapter lifecycle --- + + def connect(self) -> None: + """Verify CrewAI is importable and mark as connected.""" + try: + import crewai # type: ignore[import-not-found,unused-ignore] # noqa: F401 + + version = getattr(crewai, "__version__", "unknown") + logger.debug("CrewAI %s detected", version) + except ImportError: + logger.debug("CrewAI not installed; adapter usable in mock/test mode") + self._framework_version = self._detect_framework_version() + self._connected = True + self._status = AdapterStatus.HEALTHY + + def disconnect(self) -> None: + """Flush and disconnect.""" + self._connected = False + self._status = AdapterStatus.DISCONNECTED + + def health_check(self) -> AdapterHealth: + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + framework_version=self._framework_version, + adapter_version=self.VERSION, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="CrewAIAdapter", + version=self.VERSION, + framework=self.FRAMEWORK, + framework_version=self._framework_version, + capabilities=[ + AdapterCapability.TRACE_TOOLS, + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_STATE, + AdapterCapability.TRACE_HANDOFFS, + ], + description="LayerLens 
adapter for CrewAI agent framework", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + return ReplayableTrace( + adapter_name="CrewAIAdapter", + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + state_snapshots=[], + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + # --- Crew instrumentation --- + + def instrument_crew(self, crew: Any) -> Any: + """ + Instrument a CrewAI Crew with STRATIX tracing. + + Registers LayerLensCrewCallback on the crew. Records process type + and agent metadata. + + Args: + crew: A CrewAI Crew instance + + Returns: + The modified crew (same object, with callback attached) + """ + from layerlens.instrument.adapters.frameworks.crewai.callbacks import LayerLensCrewCallback + + callback = LayerLensCrewCallback(adapter=self) + + # Record process type + process_type = getattr(crew, "process", None) + if process_type is not None: + process_type = str(process_type) + + # Attach callback - CrewAI supports step_callback and task_callback + try: + if hasattr(crew, "step_callback"): + crew.step_callback = callback.on_step + if hasattr(crew, "task_callback"): + crew.task_callback = callback.on_task_complete + except Exception: + logger.debug("Could not attach callbacks to crew", exc_info=True) + + # Store callback reference for lifecycle hooks + crew._stratix_callback = callback + crew._stratix_adapter = self + + # Extract agent metadata on first encounter + agents = getattr(crew, "agents", []) or [] + for agent in agents: + self._emit_agent_config(agent, process_type) + + return crew + + # --- Lifecycle hooks (called by callback) --- + + def on_crew_start(self, crew_input: Any = None) -> None: + """ + Handle crew execution start. + + Emits agent.input (L1). 
+ """ + with self._adapter_lock: + self._crew_start_ns = time.time_ns() + + self.emit_dict_event( + "agent.input", + { + "framework": "crewai", + "input": self._safe_serialize(crew_input), + "timestamp_ns": self._crew_start_ns, + }, + ) + + def on_crew_end( + self, + crew_output: Any = None, + error: Exception | None = None, + ) -> None: + """ + Handle crew execution end. + + Emits agent.output (L1). + """ + end_ns = time.time_ns() + duration_ns = end_ns - self._crew_start_ns if self._crew_start_ns else 0 + + payload: dict[str, Any] = { + "framework": "crewai", + "output": self._safe_serialize(crew_output), + "duration_ns": duration_ns, + } + if error: + payload["error"] = str(error) + + self.emit_dict_event("agent.output", payload) + + def on_task_start( + self, + task_description: str, + agent_role: str | None = None, + expected_output: str | None = None, + task_order: int | None = None, + ) -> None: + """ + Handle task start. + + Emits agent.code (L2) as dict event with task metadata. + """ + payload: dict[str, Any] = { + "framework": "crewai", + "task_description": task_description, + "event_subtype": "task_start", + } + if agent_role: + payload["agent_role"] = agent_role + if expected_output: + payload["expected_output"] = expected_output + if task_order is not None: + payload["task_order"] = task_order + + self.emit_dict_event("agent.code", payload) + + def on_task_end( + self, + task_output: Any = None, + agent_role: str | None = None, + task_order: int | None = None, + error: Exception | None = None, + ) -> None: + """ + Handle task completion. + + Emits agent.state.change (cross-cutting) and cost.record (cross-cutting) + if token costs are available. 
+ """ + payload: dict[str, Any] = { + "framework": "crewai", + "task_output": self._safe_serialize(task_output), + "event_subtype": "task_complete", + } + if agent_role: + payload["agent_role"] = agent_role + if task_order is not None: + payload["task_order"] = task_order + if error: + payload["error"] = str(error) + + self.emit_dict_event("agent.state.change", payload) + + # Emit cost record if token usage available + token_usage = self._extract_token_usage(task_output) + if token_usage: + self.emit_dict_event( + "cost.record", + { + "framework": "crewai", + "agent_role": agent_role, + **token_usage, + }, + ) + + def on_tool_use( + self, + tool_name: str, + tool_input: Any = None, + tool_output: Any = None, + error: Exception | None = None, + latency_ms: float | None = None, + ) -> None: + """ + Handle tool usage. + + Emits tool.call (L5a). + """ + payload: dict[str, Any] = { + "framework": "crewai", + "tool_name": tool_name, + "tool_input": self._safe_serialize(tool_input), + "tool_output": self._safe_serialize(tool_output), + } + if error: + payload["error"] = str(error) + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + self.emit_dict_event("tool.call", payload) + + def on_llm_call( + self, + provider: str | None = None, + model: str | None = None, + tokens_prompt: int | None = None, + tokens_completion: int | None = None, + latency_ms: float | None = None, + messages: list[dict[str, str]] | None = None, + ) -> None: + """ + Handle LLM invocation. + + Emits model.invoke (L3). 
+ """ + payload: dict[str, Any] = { + "framework": "crewai", + } + if provider: + payload["provider"] = provider + if model: + payload["model"] = model + if tokens_prompt is not None: + payload["tokens_prompt"] = tokens_prompt + if tokens_completion is not None: + payload["tokens_completion"] = tokens_completion + if latency_ms is not None: + payload["latency_ms"] = latency_ms + if self._capture_config.capture_content and messages: + payload["messages"] = messages + + self.emit_dict_event("model.invoke", payload) + + def on_delegation( + self, + from_agent: str, + to_agent: str, + context: Any = None, + ) -> None: + """ + Handle agent delegation. + + Emits agent.handoff (cross-cutting, always enabled). + """ + self._delegation_tracker.track_delegation(from_agent, to_agent, context) + + # --- Memory integration --- + + def inject_memory_context( + self, + agent_id: str, + task_context: str, + ) -> str: + """Retrieve relevant semantic memories and prepend them to the task context. + + When no ``memory_service`` is configured the original ``task_context`` + is returned unmodified (backward compatible). + + Args: + agent_id: Agent whose memories to retrieve. + task_context: Original task context string. + + Returns: + Enriched context with relevant memories prepended, or the + original context when the memory service is unavailable or + no memories are found. 
+ """ + if self._memory_service is None: + return task_context + + try: + memories = self._memory_service.search(agent_id, task_context, limit=5) + if not memories: + return task_context + + memory_lines = [f"- [{m.key}]: {m.content[:200]}" for m in memories] + header = "Relevant memories:\n" + "\n".join(memory_lines) + "\n\n" + return header + task_context + except Exception: + logger.debug( + "Failed to inject memory context for agent %s", + agent_id, + exc_info=True, + ) + return task_context + + def store_task_result( + self, + agent_id: str, + task_name: str, + result: Any, + ) -> None: + """Store a task result as procedural memory. + + Only active when ``memory_service`` is provided. Failures are + logged and swallowed. + + Args: + agent_id: Agent that completed the task. + task_name: Name or description of the task. + result: Task result to persist. + """ + if self._memory_service is None: + return + + try: + from layerlens.instrument._vendored.memory_models import MemoryEntry + + content = self._safe_serialize(result) + entry = MemoryEntry( + org_id=getattr(self._stratix, "org_id", ""), + agent_id=agent_id, + memory_type="procedural", + key=f"task_result_{task_name}", + content=str(content), + importance=0.6, + metadata={"source": "crewai_adapter", "task_name": task_name}, + ) + self._memory_service.store(entry) + except Exception: + logger.debug( + "Failed to store task result memory for agent %s task %s", + agent_id, + task_name, + exc_info=True, + ) + + # --- Agent config emission --- + + def _emit_agent_config( + self, + agent: Any, + process_type: str | None = None, + ) -> None: + """Emit environment.config for an agent on first encounter.""" + role = getattr(agent, "role", None) or str(agent) + with self._adapter_lock: + if role in self._seen_agents: + return + self._seen_agents.add(role) + + metadata = self._metadata_extractor.extract(agent) + if process_type: + metadata["process_type"] = process_type + + self.emit_dict_event( + "environment.config", 
+ { + "framework": "crewai", + "agent_role": role, + **metadata, + }, + ) + + # --- Internal helpers --- + + def _safe_serialize(self, value: Any) -> Any: + """Safely serialize a value for events.""" + try: + if value is None: + return None + if hasattr(value, "model_dump"): + return value.model_dump() + if hasattr(value, "dict"): + return value.dict() + if isinstance(value, dict): + return dict(value) + if isinstance(value, (str, int, float, bool)): + return value + return str(value) + except Exception: + return str(value) + + def _extract_token_usage(self, task_output: Any) -> dict[str, Any] | None: + """Extract token usage from task output if available.""" + if task_output is None: + return None + try: + usage = getattr(task_output, "token_usage", None) + if usage and isinstance(usage, dict): + return { + "tokens_prompt": usage.get("prompt_tokens"), + "tokens_completion": usage.get("completion_tokens"), + "tokens_total": usage.get("total_tokens"), + } + except Exception: + pass + return None + + @staticmethod + def _detect_framework_version() -> str | None: + try: + import crewai # type: ignore[import-not-found,unused-ignore] + + return getattr(crewai, "__version__", None) + except ImportError: + return None diff --git a/src/layerlens/instrument/adapters/frameworks/crewai/metadata.py b/src/layerlens/instrument/adapters/frameworks/crewai/metadata.py new file mode 100644 index 00000000..85835b31 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/crewai/metadata.py @@ -0,0 +1,65 @@ +""" +CrewAI Agent Metadata Extraction + +Extracts agent metadata for L4a (environment.config) emission. +""" + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +class AgentMetadataExtractor: + """Extracts CrewAI agent metadata for L4a emission.""" + + def extract(self, agent: Any) -> dict[str, Any]: + """ + Extract metadata from a CrewAI Agent.
+ + Args: + agent: A CrewAI Agent instance + + Returns: + Dict of agent metadata + """ + metadata: dict[str, Any] = {} + + for attr in ( + "role", + "goal", + "backstory", + "verbose", + "allow_delegation", + "max_iter", + "memory", + ): + try: + val = getattr(agent, attr, None) + if val is not None: + metadata[attr] = val + except Exception: + pass + + # Extract tool names + try: + tools = getattr(agent, "tools", None) + if tools: + metadata["tools"] = [getattr(t, "name", str(t)) for t in tools] + except Exception: + pass + + # Extract LLM model info + try: + llm = getattr(agent, "llm", None) + if llm is not None: + model_name = ( + getattr(llm, "model_name", None) or getattr(llm, "model", None) or str(llm) + ) + metadata["llm_model"] = model_name + except Exception: + pass + + return metadata diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/__init__.py b/src/layerlens/instrument/adapters/frameworks/langchain/__init__.py new file mode 100644 index 00000000..be91bc5f --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/__init__.py @@ -0,0 +1,55 @@ +""" +STRATIX LangChain Adapter + +Integrates STRATIX tracing with LangChain framework using callbacks. + +Usage: + from layerlens.instrument.adapters.frameworks.langchain import ( + LayerLensCallbackHandler, + instrument_chain, + instrument_agent, + ) + + # Create callback handler + handler = LayerLensCallbackHandler(stratix_instance) + + # Use with LangChain components + llm = ChatOpenAI(callbacks=[handler]) + chain = LLMChain(llm=llm, callbacks=[handler]) + + # Or instrument existing chain/agent + traced_chain = instrument_chain(chain, stratix_instance) +""" + +from __future__ import annotations + +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat, requires_pydantic + +# Round-2 deliberation item 20: fail fast under v1 with a clear message +# rather than letting LangChain raise an opaque ImportError mid-callback. 
+requires_pydantic(PydanticCompat.V2_ONLY) + +from layerlens.instrument.adapters.frameworks.langchain.state import LangChainMemoryAdapter +from layerlens.instrument.adapters.frameworks.langchain.agents import TracedAgent, instrument_agent +from layerlens.instrument.adapters.frameworks.langchain.chains import TracedChain, instrument_chain +from layerlens.instrument.adapters.frameworks.langchain.memory import TracedMemory, wrap_memory +from layerlens.instrument.adapters.frameworks.langchain.callbacks import LayerLensCallbackHandler + +# Registry lazy-loading convention +ADAPTER_CLASS = LayerLensCallbackHandler + +__all__ = [ + "LayerLensCallbackHandler", + "LangChainMemoryAdapter", + "TracedMemory", + "wrap_memory", + "instrument_chain", + "TracedChain", + "instrument_agent", + "TracedAgent", + "ADAPTER_CLASS", +] + + +# Backward-compat aliases for users coming from ateam. +STRATIXCallbackHandler = LayerLensCallbackHandler # noqa: N816 - backward-compat alias for ateam users diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/agents.py b/src/layerlens/instrument/adapters/frameworks/langchain/agents.py new file mode 100644 index 00000000..8517cc24 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/agents.py @@ -0,0 +1,381 @@ +""" +STRATIX LangChain Agent Instrumentation + +Provides automatic instrumentation for LangChain agents. 
+""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any +from dataclasses import field, dataclass + +from layerlens.instrument.adapters.frameworks.langchain.callbacks import LayerLensCallbackHandler + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + +logger = logging.getLogger(__name__) + + +@dataclass +class AgentStep: + """Represents a single step in agent execution.""" + + step_number: int + action: str | None = None + action_input: Any | None = None + observation: str | None = None + timestamp_ns: int | None = None + + +@dataclass +class AgentExecution: + """Tracks a complete agent execution.""" + + agent_type: str + start_time_ns: int + end_time_ns: int | None = None + input: str | dict[str, Any] | None = None + output: Any | None = None + steps: list[AgentStep] = field(default_factory=list) + error: str | None = None + + +class TracedAgent: + """ + Wrapper around a LangChain agent with STRATIX tracing. + + Captures: + - Agent input/output + - Intermediate reasoning steps + - Tool calls during execution + - LLM invocations + + Usage: + from langchain.agents import create_react_agent # type: ignore[import-untyped,unused-ignore] + + agent = create_react_agent(llm, tools, prompt) + traced_agent = TracedAgent(agent, stratix_instance) + + # Use as normal + result = traced_agent.invoke({"input": "What is the weather?"}) + """ + + def __init__( + self, + agent: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + ) -> None: + """ + Initialize the traced agent. 
+ + Args: + agent: LangChain agent instance (AgentExecutor or similar) + stratix_instance: STRATIX SDK instance (legacy) + adapter: BaseAdapter instance (new-style) + """ + self._agent = agent + self._stratix = stratix_instance + self._adapter = adapter + self._handler = LayerLensCallbackHandler( + stratix=adapter._stratix if adapter else None, + capture_config=adapter.capture_config if adapter else None, + stratix_instance=stratix_instance, + ) + self._agent_type = type(agent).__name__ + self._executions: list[AgentExecution] = [] + self._current_execution: AgentExecution | None = None + self._step_counter = 0 + + def invoke( + self, + input: dict[str, Any] | str, + config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Invoke the agent with tracing. + + Args: + input: Agent input + config: Optional config + **kwargs: Additional arguments + + Returns: + Agent output + """ + execution = AgentExecution( + agent_type=self._agent_type, + start_time_ns=time.time_ns(), + input=input, + ) + self._executions.append(execution) + self._current_execution = execution + self._step_counter = 0 + + # Emit agent input event + self._emit_agent_input(input) + + # Inject callback handler + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + try: + result = self._agent.invoke(input, config, **kwargs) + + execution.end_time_ns = time.time_ns() + execution.output = result + + # Emit agent output event + self._emit_agent_output(execution) + + return result # type: ignore[no-any-return] + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_agent_output(execution) + raise + finally: + self._current_execution = None + + async def ainvoke( + self, + input: dict[str, Any] | str, + config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Async invoke the agent with 
tracing. + + Args: + input: Agent input + config: Optional config + **kwargs: Additional arguments + + Returns: + Agent output + """ + execution = AgentExecution( + agent_type=self._agent_type, + start_time_ns=time.time_ns(), + input=input, + ) + self._executions.append(execution) + self._current_execution = execution + self._step_counter = 0 + + self._emit_agent_input(input) + + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + try: + result = await self._agent.ainvoke(input, config, **kwargs) + + execution.end_time_ns = time.time_ns() + execution.output = result + + self._emit_agent_output(execution) + + return result # type: ignore[no-any-return] + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_agent_output(execution) + raise + finally: + self._current_execution = None + + def run(self, *args: Any, **kwargs: Any) -> str: + """ + Run the agent (deprecated method). + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Agent output string + """ + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + return self._agent.run(*args, **kwargs) # type: ignore[no-any-return] + + def record_step( + self, + action: str | None = None, + action_input: Any = None, + observation: str | None = None, + ) -> None: + """ + Record an intermediate step. + + Called automatically by callback handler but can be + called manually for custom step tracking. 
+ + Args: + action: The action taken + action_input: Input to the action + observation: Result of the action + """ + if self._current_execution is None: + return + + self._step_counter += 1 + step = AgentStep( + step_number=self._step_counter, + action=action, + action_input=action_input, + observation=observation, + timestamp_ns=time.time_ns(), + ) + self._current_execution.steps.append(step) + + def _emit_agent_input(self, input: Any) -> None: + """Emit agent.input event.""" + payload = { + "agent_type": self._agent_type, + "input": input, + "timestamp_ns": time.time_ns(), + } + + if self._adapter is not None: + try: + from layerlens.instrument._vendored.events import ( + MessageRole, + AgentInputEvent, + ) + + msg = str(input) if not isinstance(input, str) else input + typed_payload = AgentInputEvent.create(message=msg, role=MessageRole.HUMAN) + self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit("agent.input", payload) + + def _emit_agent_output(self, execution: AgentExecution) -> None: + """Emit agent.output event.""" + duration_ns = (execution.end_time_ns or 0) - execution.start_time_ns + payload = { + "agent_type": execution.agent_type, + "input": execution.input, + "output": execution.output, + "num_steps": len(execution.steps), + "duration_ns": duration_ns, + "error": execution.error, + } + + if self._adapter is not None: + try: + from layerlens.instrument._vendored.events import AgentOutputEvent + + msg = str(execution.output) if execution.output else "" + typed_payload = AgentOutputEvent.create(message=msg) + self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit("agent.output", payload) + + @property + 
def callback_handler(self) -> LayerLensCallbackHandler: + """Get the callback handler.""" + return self._handler + + @property + def executions(self) -> list[AgentExecution]: + """Get all recorded executions.""" + return self._executions + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to underlying agent.""" + return getattr(self._agent, name) + + +def instrument_agent( + agent: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, +) -> TracedAgent: + """ + Instrument a LangChain agent with STRATIX tracing. + + Args: + agent: LangChain agent instance + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + + Returns: + TracedAgent wrapper + """ + return TracedAgent(agent, stratix_instance, adapter=adapter) + + +class AgentTracer: + """ + Tracer for multiple agent executions. + + Provides a unified view of agent activity across + multiple invocations. + """ + + def __init__(self, stratix_instance: Any = None, adapter: BaseAdapter | None = None) -> None: + """ + Initialize the agent tracer. + + Args: + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + """ + self._stratix = stratix_instance + self._adapter = adapter + self._agents: dict[str, TracedAgent] = {} + self._all_executions: list[AgentExecution] = [] + + def trace(self, agent: Any, name: str | None = None) -> TracedAgent: + """ + Start tracing an agent. 
+ + Args: + agent: LangChain agent + name: Optional name for the agent + + Returns: + TracedAgent wrapper + """ + agent_name = name or type(agent).__name__ + traced = TracedAgent(agent, self._stratix, adapter=self._adapter) + self._agents[agent_name] = traced + return traced + + def get_agent(self, name: str) -> TracedAgent | None: + """Get a traced agent by name.""" + return self._agents.get(name) + + def get_all_executions(self) -> list[AgentExecution]: + """Get all executions across all agents.""" + all_execs = [] + for agent in self._agents.values(): + all_execs.extend(agent.executions) + return sorted(all_execs, key=lambda e: e.start_time_ns) + + def get_total_steps(self) -> int: + """Get total number of steps across all executions.""" + return sum(len(e.steps) for agent in self._agents.values() for e in agent.executions) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/callbacks.py b/src/layerlens/instrument/adapters/frameworks/langchain/callbacks.py new file mode 100644 index 00000000..f0854b52 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/callbacks.py @@ -0,0 +1,801 @@ +""" +STRATIX LangChain Callback Handler + +Provides LangChain callback-based integration for STRATIX tracing. 
+""" + +from __future__ import annotations + +import time +import uuid +from uuid import UUID +from typing import Any +from dataclasses import dataclass +from collections.abc import Callable + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters._base.trace_container import SerializedTrace + + +@dataclass +class LLMCallContext: + """Context for tracking an LLM call.""" + + run_id: str + start_time_ns: int + model: str | None = None + provider: str | None = None + prompts: list[str] | None = None + invocation_params: dict[str, Any] | None = None + + +@dataclass +class ToolCallContext: + """Context for tracking a tool call.""" + + run_id: str + start_time_ns: int + tool_name: str + tool_input: str | dict[str, Any] | None = None + + +@dataclass +class AgentActionContext: + """Context for tracking an agent action.""" + + run_id: str + start_time_ns: int + action: str | None = None + action_input: Any | None = None + + +@dataclass +class ChainCallContext: + """Context for tracking a chain/node execution.""" + + run_id: str + start_time_ns: int + node_name: str | None = None + parent_run_id: str | None = None + + +class LayerLensCallbackHandler(BaseAdapter): + """ + LangChain callback handler that emits STRATIX events. + + Implements the LangChain callback interface to capture: + - model.invoke (L3) events from LLM calls + - tool.call (L5a) events from tool invocations and agent actions + - agent.input / agent.output (L1) events from LangGraph node runs and agent finishes + + Extends BaseAdapter for unified lifecycle and circuit-breaker support. + + Supports both new-style (stratix, capture_config) and legacy-style + (stratix_instance, boolean flags) parameters. 
+ + Usage (new): + from stratix import STRATIX + from layerlens.instrument.adapters.frameworks.langchain import LayerLensCallbackHandler + + stratix = STRATIX(policy_ref="my-policy") + handler = LayerLensCallbackHandler(stratix=stratix) + handler.connect() + llm = ChatOpenAI(callbacks=[handler]) + + Usage (legacy — still supported): + handler = LayerLensCallbackHandler(stratix_instance=stratix) + llm = ChatOpenAI(callbacks=[handler]) + """ + + FRAMEWORK = "langchain" + VERSION = "0.1.0" + # LangChain >=0.2 (pyproject pin: langchain>=0.2,<0.4) migrated all + # internal models to Pydantic v2 — see langchain/langchain#21238 and + # langchain-core 0.2.0 release notes. Importing the langchain runtime + # under Pydantic v1 raises at import inside langchain itself. + requires_pydantic = PydanticCompat.V2_ONLY + + # LangChain callback protocol attributes — required by CallbackManager + raise_error: bool = False + ignore_llm: bool = False + ignore_chain: bool = False + ignore_agent: bool = False + ignore_chat_model: bool = False + ignore_retriever: bool = True + ignore_retry: bool = True + ignore_custom_event: bool = True + + def __init__( + self, + # New-style params + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + event_sinks: list[Any] | None = None, + graph_factory: Callable[[], Any] | None = None, + # Legacy params (backward compat) + stratix_instance: Any | None = None, + emit_llm_events: bool = True, + emit_tool_events: bool = True, + emit_agent_events: bool = True, + ) -> None: + """ + Initialize the callback handler. 
+ + Args: + stratix: STRATIX SDK instance (new-style) + capture_config: CaptureConfig (new-style) + event_sinks: Optional list of EventSink instances for persistence + graph_factory: Optional callable that returns a fresh graph for replay + stratix_instance: STRATIX SDK instance (legacy) + emit_llm_events: Whether to emit model.invoke events (legacy) + emit_tool_events: Whether to emit tool.call events (legacy) + emit_agent_events: Whether to emit agent events (legacy) + """ + # Resolve STRATIX instance + resolved_stratix = stratix or stratix_instance + + # Map legacy booleans → CaptureConfig when any flag differs from default + if capture_config is None: + any_legacy = not emit_llm_events or not emit_tool_events or not emit_agent_events + if any_legacy or stratix_instance is not None: + capture_config = CaptureConfig( + l3_model_metadata=emit_llm_events, + l5a_tool_calls=emit_tool_events, + l1_agent_io=emit_agent_events, + ) + + super().__init__( + stratix=resolved_stratix, + capture_config=capture_config, + event_sinks=event_sinks, + ) + + # Graph factory for replay re-execution + self._graph_factory = graph_factory + + # Legacy compat: keep booleans accessible + self._emit_llm_events = emit_llm_events + self._emit_tool_events = emit_tool_events + self._emit_agent_events = emit_agent_events + + # Track active calls + self._llm_calls: dict[str, LLMCallContext] = {} + self._tool_calls: dict[str, ToolCallContext] = {} + self._agent_actions: dict[str, AgentActionContext] = {} + self._chain_calls: dict[str, ChainCallContext] = {} + self._run_to_node: dict[str, str] = {} # run_id -> langgraph node name + + # Track all events for debugging/testing + self._events: list[dict[str, Any]] = [] + + # --- BaseAdapter lifecycle --- + + def connect(self) -> None: + """Verify LangChain is importable and mark as connected.""" + try: + import langchain # type: ignore[import-not-found,unused-ignore] # noqa: F401 + + self._connected = True + self._status = AdapterStatus.HEALTHY + 
except ImportError: + self._connected = True + self._status = AdapterStatus.HEALTHY + + def disconnect(self) -> None: + self._close_sinks() + self._connected = False + self._status = AdapterStatus.DISCONNECTED + + def health_check(self) -> AdapterHealth: + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + framework_version=self._detect_framework_version(), + adapter_version=self.VERSION, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="LayerLensCallbackHandler", + version=self.VERSION, + framework=self.FRAMEWORK, + framework_version=self._detect_framework_version(), + capabilities=[ + AdapterCapability.TRACE_TOOLS, + AdapterCapability.TRACE_MODELS, + AdapterCapability.REPLAY, + ], + description="LayerLens adapter for LangChain framework (callback-based)", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + trace_id = str(uuid.uuid4()) + return ReplayableTrace( + adapter_name="LayerLensCallbackHandler", + framework=self.FRAMEWORK, + trace_id=trace_id, + events=list(self._trace_events), + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + # --- Replay execution --- + + async def execute_replay( + self, + inputs: dict[str, Any], + original_trace: Any, + request: Any, + replay_trace_id: str, + ) -> SerializedTrace: + """ + Re-execute through LangChain/LangGraph with a fresh graph. + + Requires a ``graph_factory`` to have been provided at construction. + + Args: + inputs: Reconstructed inputs for the replay. + original_trace: The original SerializedTrace. + request: The ReplayRequest. + replay_trace_id: ID for the new replay trace. + + Returns: + SerializedTrace from the replay execution. + + Raises: + NotImplementedError: If no graph_factory is registered. 
+ """ + if self._graph_factory is None: + raise NotImplementedError("No graph_factory registered for replay") + + # Build a fresh graph and callback handler + graph = self._graph_factory() + replay_handler = LayerLensCallbackHandler(event_sinks=[]) + replay_handler.connect() + + try: + # Re-execute through LangGraph with new callbacks + graph.invoke(inputs, config={"callbacks": [replay_handler]}) + + return SerializedTrace.from_event_records( + events=list(replay_handler._trace_events), + trace_id=replay_trace_id, + metadata={ + "replay_of": original_trace.trace_id, + "framework": "langgraph", + "replay_type": getattr(request, "replay_type", "basic"), + }, + ) + finally: + replay_handler.disconnect() + + # --- Chat Model Callbacks --- + + def on_chat_model_start( + self, + serialized: dict[str, Any], + messages: list[list[Any]], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Called when a chat model starts running. + + ChatOpenAI (used by OpenRouter, OpenAI, etc.) triggers this + instead of on_llm_start. We extract the messages and delegate + to the same tracking logic. 
+ """ + if not self._capture_config.is_layer_enabled("model.invoke"): + return + + run_id_str = str(run_id) + model = self._extract_model_name(serialized) + provider = self._extract_provider(serialized) + invocation_params = kwargs.get("invocation_params", {}) + + # Flatten messages to prompt strings for consistent storage + prompts: list[str] = [] + for message_group in messages: + for msg in message_group: + content = getattr(msg, "content", str(msg)) + role = getattr(msg, "type", "unknown") + prompts.append(f"[{role}] {content}") + + self._llm_calls[run_id_str] = LLMCallContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + model=model, + provider=provider, + prompts=prompts, + invocation_params=invocation_params, + ) + + # --- LLM Callbacks --- + + def on_llm_start( + self, + serialized: dict[str, Any], + prompts: list[str], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Called when LLM starts running.""" + if not self._capture_config.is_layer_enabled("model.invoke"): + return + + run_id_str = str(run_id) + + # Extract model/provider info + model = self._extract_model_name(serialized) + provider = self._extract_provider(serialized) + invocation_params = kwargs.get("invocation_params", {}) + + self._llm_calls[run_id_str] = LLMCallContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + model=model, + provider=provider, + prompts=prompts, + invocation_params=invocation_params, + ) + + def on_llm_end( + self, + response: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when LLM finishes running.""" + if not self._capture_config.is_layer_enabled("model.invoke"): + return + + run_id_str = str(run_id) + ctx = self._llm_calls.pop(run_id_str, None) + + if ctx is None: + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + # Extract response 
content + output = self._extract_llm_output(response) + token_usage = self._extract_token_usage(response) + + payload = { + "run_id": run_id_str, + "model": {"name": ctx.model or "unknown", "provider": ctx.provider or "unknown"}, + "prompts": ctx.prompts or [], + "output": output, + "token_usage": token_usage, + "duration_ns": duration_ns, + "invocation_params": ctx.invocation_params, + } + + # Attribute to LangGraph node if parent chain is a node + node_name = self._run_to_node.get(str(parent_run_id)) if parent_run_id else None + if node_name: + payload["node_name"] = node_name + + self._emit_event("model.invoke", payload) + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when LLM errors.""" + if not self._capture_config.is_layer_enabled("model.invoke"): + return + + run_id_str = str(run_id) + ctx = self._llm_calls.pop(run_id_str, None) + + if ctx is None: + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + payload = { + "run_id": run_id_str, + "model": {"name": ctx.model or "unknown", "provider": ctx.provider or "unknown"}, + "prompts": ctx.prompts or [], + "error": str(error), + "duration_ns": duration_ns, + } + + # Attribute to LangGraph node if parent chain is a node + node_name = self._run_to_node.get(str(parent_run_id)) if parent_run_id else None + if node_name: + payload["node_name"] = node_name + + self._emit_event("model.invoke", payload) + + # --- Tool Callbacks --- + + def on_tool_start( + self, + serialized: dict[str, Any], + input_str: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + inputs: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Called when tool starts running.""" + if not self._capture_config.is_layer_enabled("tool.call"): + return + + run_id_str = str(run_id) + tool_name = 
serialized.get("name", "unknown_tool") + + self._tool_calls[run_id_str] = ToolCallContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + tool_name=tool_name, + tool_input=inputs if inputs else input_str, + ) + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when tool finishes running.""" + if not self._capture_config.is_layer_enabled("tool.call"): + return + + run_id_str = str(run_id) + ctx = self._tool_calls.pop(run_id_str, None) + + if ctx is None: + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + payload = { + "run_id": run_id_str, + "tool_name": ctx.tool_name, + "input": ctx.tool_input, + "output": output, + "duration_ns": duration_ns, + } + + # Attribute to LangGraph node if parent chain is a node + node_name = self._run_to_node.get(str(parent_run_id)) if parent_run_id else None + if node_name: + payload["node_name"] = node_name + + self._emit_event("tool.call", payload) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when tool errors.""" + if not self._capture_config.is_layer_enabled("tool.call"): + return + + run_id_str = str(run_id) + ctx = self._tool_calls.pop(run_id_str, None) + + if ctx is None: + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + payload = { + "run_id": run_id_str, + "tool_name": ctx.tool_name, + "input": ctx.tool_input, + "error": str(error), + "duration_ns": duration_ns, + } + + # Attribute to LangGraph node if parent chain is a node + node_name = self._run_to_node.get(str(parent_run_id)) if parent_run_id else None + if node_name: + payload["node_name"] = node_name + + self._emit_event("tool.call", payload) + + # --- Agent Callbacks --- + + def on_agent_action( + self, + action: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + 
**kwargs: Any, + ) -> None: + """Called when agent takes an action.""" + if not self._capture_config.is_layer_enabled("agent.input"): + return + + run_id_str = str(run_id) + + # Extract action details + action_str = ( + getattr(action, "tool", str(action)) if hasattr(action, "tool") else str(action) + ) + action_input = getattr(action, "tool_input", None) + + self._agent_actions[run_id_str] = AgentActionContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + action=action_str, + action_input=action_input, + ) + + self._emit_event( + "tool.call", + { + "run_id": run_id_str, + "tool_name": action_str, + "tool_input": action_input, + }, + ) + + def on_agent_finish( + self, + finish: Any, + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when agent finishes.""" + if not self._capture_config.is_layer_enabled("agent.output"): + return + + run_id_str = str(run_id) + + # Extract output + output = getattr(finish, "return_values", str(finish)) + log = getattr(finish, "log", None) + + self._emit_event( + "agent.output", + { + "run_id": run_id_str, + "output": output, + "log": log, + }, + ) + + # --- Chain Callbacks --- + + def on_chain_start( + self, + serialized: dict[str, Any], + inputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Called when chain starts running. + + For LangGraph node executions, metadata contains 'langgraph_node' + with the node name. We emit agent.input and track the run_id so + child LLM/tool calls can be attributed to the node. 
+ """ + run_id_str = str(run_id) + parent_id_str = str(parent_run_id) if parent_run_id else None + meta = metadata or {} + + node_name = meta.get("langgraph_node") + + if node_name: + # This is a LangGraph node execution + self._chain_calls[run_id_str] = ChainCallContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + node_name=node_name, + parent_run_id=parent_id_str, + ) + self._run_to_node[run_id_str] = node_name + + if self._capture_config.is_layer_enabled("agent.input"): + input_summary = str(inputs)[:500] if inputs else None + self._emit_event( + "agent.input", + { + "run_id": run_id_str, + "node_name": node_name, + "input": input_summary, + "langgraph_step": meta.get("langgraph_step"), + "langgraph_triggers": meta.get("langgraph_triggers"), + }, + ) + elif parent_id_str and parent_id_str in self._run_to_node: + # Sub-chain within a LangGraph node — inherit the node mapping + inherited_node = self._run_to_node[parent_id_str] + self._run_to_node[run_id_str] = inherited_node + self._chain_calls[run_id_str] = ChainCallContext( + run_id=run_id_str, + start_time_ns=time.time_ns(), + node_name=inherited_node, + parent_run_id=parent_id_str, + ) + + def on_chain_end( + self, + outputs: dict[str, Any], + *, + run_id: UUID, + parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when chain finishes running.""" + run_id_str = str(run_id) + ctx = self._chain_calls.pop(run_id_str, None) + self._run_to_node.pop(run_id_str, None) + + if ctx is None or ctx.node_name is None: + return + + if not self._capture_config.is_layer_enabled("agent.output"): + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + output_summary = str(outputs)[:500] if outputs else None + self._emit_event( + "agent.output", + { + "run_id": run_id_str, + "node_name": ctx.node_name, + "output": output_summary, + "duration_ns": duration_ns, + }, + ) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + 
parent_run_id: UUID | None = None, + **kwargs: Any, + ) -> None: + """Called when chain errors.""" + run_id_str = str(run_id) + ctx = self._chain_calls.pop(run_id_str, None) + self._run_to_node.pop(run_id_str, None) + + if ctx is None or ctx.node_name is None: + return + + if not self._capture_config.is_layer_enabled("agent.output"): + return + + end_time_ns = time.time_ns() + duration_ns = end_time_ns - ctx.start_time_ns + + self._emit_event( + "agent.output", + { + "run_id": run_id_str, + "node_name": ctx.node_name, + "error": str(error), + "duration_ns": duration_ns, + }, + ) + + # --- Helper Methods --- + + def _extract_model_name(self, serialized: dict[str, Any]) -> str: + """Extract model name from serialized LLM.""" + for key in ["model_name", "model", "name"]: + if key in serialized: + return serialized[key] # type: ignore[no-any-return] + + kwargs = serialized.get("kwargs", {}) + for key in ["model_name", "model"]: + if key in kwargs: + return kwargs[key] # type: ignore[no-any-return] + + return "unknown" + + def _extract_provider(self, serialized: dict[str, Any]) -> str: + """Extract provider from serialized LLM.""" + id_parts = serialized.get("id", ["unknown"]) + if isinstance(id_parts, list) and len(id_parts) >= 3: + return id_parts[2] if len(id_parts) > 2 else "unknown" + + name = serialized.get("name", "").lower() + if "openai" in name: + return "openai" + elif "anthropic" in name or "claude" in name: + return "anthropic" + elif "google" in name or "gemini" in name: + return "google" + + return "unknown" + + def _extract_llm_output(self, response: Any) -> Any: + """Extract output from LLM response.""" + if hasattr(response, "generations"): + generations = response.generations + if generations and len(generations) > 0: + gen = generations[0] + if isinstance(gen, list) and len(gen) > 0: + return gen[0].text if hasattr(gen[0], "text") else str(gen[0]) + return gen.text if hasattr(gen, "text") else str(gen) + + return str(response) + + def 
_extract_token_usage(self, response: Any) -> dict[str, int] | None: + """Extract token usage from response.""" + if hasattr(response, "llm_output") and response.llm_output: + return response.llm_output.get("token_usage") # type: ignore[no-any-return] + return None + + def _emit_event(self, event_type: str, payload: dict[str, Any]) -> None: + """Emit a STRATIX event through BaseAdapter's circuit-breaker path.""" + event = {"type": event_type, "payload": payload} + self._events.append(event) + self.emit_dict_event(event_type, payload) + + # --- Testing/Debugging --- + + def get_events(self, event_type: str | None = None) -> list[dict[str, Any]]: + """Get recorded events (useful for testing).""" + if event_type: + return [e for e in self._events if e["type"] == event_type] + return self._events + + def clear_events(self) -> None: + """Clear recorded events.""" + self._events.clear() + + @staticmethod + def _detect_framework_version() -> str | None: + try: + import langchain # type: ignore[import-not-found,unused-ignore] + + return getattr(langchain, "__version__", None) + except ImportError: + return None diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/chains.py b/src/layerlens/instrument/adapters/frameworks/langchain/chains.py new file mode 100644 index 00000000..df78f531 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/chains.py @@ -0,0 +1,281 @@ +""" +STRATIX LangChain Chain Instrumentation + +Provides automatic instrumentation for LangChain chains. 
+""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any +from dataclasses import dataclass + +from layerlens.instrument.adapters.frameworks.langchain.callbacks import LayerLensCallbackHandler + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + + +@dataclass +class ChainExecution: + """Tracks a single chain execution.""" + + chain_type: str + start_time_ns: int + end_time_ns: int | None = None + inputs: dict[str, Any] | None = None + outputs: dict[str, Any] | None = None + error: str | None = None + + +class TracedChain: + """ + Wrapper around a LangChain chain with STRATIX tracing. + + Automatically injects LayerLensCallbackHandler and tracks + chain executions. + + Usage: + from langchain.chains import LLMChain # type: ignore[import-untyped,unused-ignore] + + chain = LLMChain(llm=llm, prompt=prompt) + traced_chain = TracedChain(chain, stratix_instance) + + # Use as normal + result = traced_chain.invoke({"input": "hello"}) + """ + + def __init__( + self, + chain: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + ) -> None: + """ + Initialize the traced chain. + + Args: + chain: LangChain chain instance + stratix_instance: STRATIX SDK instance (legacy) + adapter: BaseAdapter instance (new-style) + """ + self._chain = chain + self._stratix = stratix_instance + self._adapter = adapter + self._handler = LayerLensCallbackHandler( + stratix=adapter._stratix if adapter else None, + capture_config=adapter.capture_config if adapter else None, + stratix_instance=stratix_instance, + ) + self._chain_type = type(chain).__name__ + self._executions: list[ChainExecution] = [] + + def invoke( + self, + input: dict[str, Any], + config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Invoke the chain with tracing. 
+ + Args: + input: Input dictionary + config: Optional config + **kwargs: Additional arguments + + Returns: + Chain output + """ + execution = ChainExecution( + chain_type=self._chain_type, + start_time_ns=time.time_ns(), + inputs=input, + ) + self._executions.append(execution) + + # Inject callback handler + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + try: + # Execute chain + result = self._chain.invoke(input, config, **kwargs) + + execution.end_time_ns = time.time_ns() + execution.outputs = result if isinstance(result, dict) else {"output": result} + + # Emit chain completion event + self._emit_chain_event(execution) + + return result # type: ignore[no-any-return] + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_chain_event(execution) + raise + + async def ainvoke( + self, + input: dict[str, Any], + config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """ + Async invoke the chain with tracing. 
+ + Args: + input: Input dictionary + config: Optional config + **kwargs: Additional arguments + + Returns: + Chain output + """ + execution = ChainExecution( + chain_type=self._chain_type, + start_time_ns=time.time_ns(), + inputs=input, + ) + self._executions.append(execution) + + # Inject callback handler + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + try: + result = await self._chain.ainvoke(input, config, **kwargs) + + execution.end_time_ns = time.time_ns() + execution.outputs = result if isinstance(result, dict) else {"output": result} + + self._emit_chain_event(execution) + + return result # type: ignore[no-any-return] + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_chain_event(execution) + raise + + def run(self, *args: Any, **kwargs: Any) -> str: + """ + Run the chain (deprecated LangChain method). + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Chain output string + """ + # Inject callback + callbacks = kwargs.get("callbacks", []) + if self._handler not in callbacks: + callbacks = list(callbacks) + [self._handler] + kwargs["callbacks"] = callbacks + + return self._chain.run(*args, **kwargs) # type: ignore[no-any-return] + + def _emit_chain_event(self, execution: ChainExecution) -> None: + """Emit chain execution event.""" + duration_ns = (execution.end_time_ns or 0) - execution.start_time_ns + payload = { + "chain_type": execution.chain_type, + "inputs": execution.inputs, + "outputs": execution.outputs, + "duration_ns": duration_ns, + "error": execution.error, + } + + # New-style: route through adapter's circuit-breaker path + if self._adapter is not None: + self._adapter.emit_dict_event("chain.execution", payload) + return + + # Legacy + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit("chain.execution", payload) + + 
@property + def callback_handler(self) -> LayerLensCallbackHandler: + """Get the callback handler.""" + return self._handler + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to underlying chain.""" + return getattr(self._chain, name) + + +def instrument_chain( + chain: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, +) -> TracedChain: + """ + Instrument a LangChain chain with STRATIX tracing. + + Args: + chain: LangChain chain instance + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + + Returns: + TracedChain wrapper + """ + return TracedChain(chain, stratix_instance, adapter=adapter) + + +class ChainTracer: + """ + Tracer for multiple chain executions. + + Useful for tracking chains in a larger workflow. + """ + + def __init__(self, stratix_instance: Any = None, adapter: BaseAdapter | None = None) -> None: + """ + Initialize the chain tracer. + + Args: + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + """ + self._stratix = stratix_instance + self._adapter = adapter + self._handler = LayerLensCallbackHandler( + stratix=adapter._stratix if adapter else None, + capture_config=adapter.capture_config if adapter else None, + stratix_instance=stratix_instance, + ) + self._chains: dict[str, TracedChain] = {} + + def trace(self, chain: Any, name: str | None = None) -> TracedChain: + """ + Start tracing a chain. 
+ + Args: + chain: LangChain chain + name: Optional name for the chain + + Returns: + TracedChain wrapper + """ + chain_name = name or type(chain).__name__ + traced = TracedChain(chain, self._stratix, adapter=self._adapter) + self._chains[chain_name] = traced + return traced + + def get_events(self, event_type: str | None = None) -> list[dict[str, Any]]: + """Get all events from the callback handler.""" + return self._handler.get_events(event_type) + + def get_chain(self, name: str) -> TracedChain | None: + """Get a traced chain by name.""" + return self._chains.get(name) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/memory.py b/src/layerlens/instrument/adapters/frameworks/langchain/memory.py new file mode 100644 index 00000000..1e9c9075 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/memory.py @@ -0,0 +1,293 @@ +""" +STRATIX LangChain Memory Tracing + +Wraps LangChain memory to emit agent.state.change events. +""" + +from __future__ import annotations + +import json +import time +from typing import Any + +from layerlens.instrument.adapters.frameworks.langchain.state import LangChainMemoryAdapter + + +class TracedMemory: + """ + Wrapper around LangChain memory that emits state change events. + + Proxies all calls to the underlying memory while tracking changes + and emitting agent.state.change events. + + Usage: + from langchain.memory import ConversationBufferMemory # type: ignore[import-untyped,unused-ignore] + + memory = ConversationBufferMemory() + traced_memory = TracedMemory(memory, stratix_instance) + + # Use as normal + traced_memory.save_context({"input": "hello"}, {"output": "hi"}) + """ + + def __init__( + self, + memory: Any, + stratix_instance: Any = None, + memory_service: Any | None = None, + ) -> None: + """ + Initialize the traced memory. + + Args: + memory: LangChain memory instance + stratix_instance: STRATIX SDK instance + memory_service: Optional AgentMemoryService for episodic memory storage. 
+ When provided, save_context() will persist a summary of each + interaction as an episodic memory entry. + """ + self._memory = memory + self._stratix = stratix_instance + self._memory_service = memory_service + self._adapter = LangChainMemoryAdapter(memory) + self._last_hash: str | None = None + + def save_context( + self, + inputs: dict[str, Any], + outputs: dict[str, str], + ) -> None: + """ + Save context to memory with state change tracking. + + Args: + inputs: Input dictionary + outputs: Output dictionary + """ + # Snapshot before + before_hash = self._adapter.get_hash() + + # Call underlying memory + self._memory.save_context(inputs, outputs) + + # Snapshot after + after_hash = self._adapter.get_hash() + + # Emit state change if changed + if before_hash != after_hash: + self._emit_state_change(before_hash, after_hash, "save_context") + + # Store episodic memory if memory_service is provided + if self._memory_service is not None: + self._store_episodic_memory(inputs, outputs) + + self._last_hash = after_hash + + def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: + """ + Load memory variables. 
+ + Args: + inputs: Input dictionary + + Returns: + Memory variables + """ + return self._memory.load_memory_variables(inputs) # type: ignore[no-any-return] + + def clear(self) -> None: + """Clear memory with state change tracking.""" + before_hash = self._adapter.get_hash() + + self._memory.clear() + + after_hash = self._adapter.get_hash() + + if before_hash != after_hash: + self._emit_state_change(before_hash, after_hash, "clear") + + self._last_hash = after_hash + + def _emit_state_change( + self, + before_hash: str, + after_hash: str, + trigger: str, + ) -> None: + """Emit agent.state.change event.""" + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit( + "agent.state.change", + { + "memory_type": type(self._memory).__name__, + "before_hash": before_hash, + "after_hash": after_hash, + "trigger": trigger, + "timestamp_ns": time.time_ns(), + }, + ) + + def _store_episodic_memory( + self, + inputs: dict[str, Any], + outputs: dict[str, str], + ) -> None: + """Store a conversation turn as episodic memory via AgentMemoryService. + + Only called when ``memory_service`` was provided at construction time. + Failures are logged and swallowed to avoid disrupting normal operation. 
+ """ + try: + from layerlens.instrument._vendored.memory_models import MemoryEntry + + timestamp = int(time.time()) + summary = json.dumps( + {"inputs": inputs, "outputs": outputs}, + default=str, + ) + entry = MemoryEntry( + org_id=getattr(self._stratix, "org_id", ""), + agent_id=getattr(self._stratix, "agent_id", "langchain"), + memory_type="episodic", + key=f"conversation_{timestamp}", + content=summary, + importance=0.5, + metadata={"source": "langchain_traced_memory"}, + ) + self._memory_service.store(entry) # type: ignore[union-attr] + except Exception: + import logging + + logging.getLogger(__name__).debug( + "Failed to store episodic memory from save_context", + exc_info=True, + ) + + @property + def memory_variables(self) -> list[str]: + """Get memory variable names.""" + return self._memory.memory_variables # type: ignore[no-any-return] + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to underlying memory.""" + return getattr(self._memory, name) + + +def wrap_memory( + memory: Any, + stratix_instance: Any = None, +) -> TracedMemory: + """ + Wrap a LangChain memory instance with STRATIX tracing. + + Args: + memory: LangChain memory instance + stratix_instance: STRATIX SDK instance + + Returns: + TracedMemory wrapper + """ + return TracedMemory(memory, stratix_instance) + + +class MemoryMutationTracker: + """ + Tracks memory mutations for a conversation. + + Useful for tracking all memory changes across multiple + LangChain invocations. + """ + + def __init__(self, stratix_instance: Any = None) -> None: + """ + Initialize the mutation tracker. + + Args: + stratix_instance: STRATIX SDK instance + """ + self._stratix = stratix_instance + self._mutations: list[dict[str, Any]] = [] + + def track_memory( + self, + memory: Any, + operation: str = "unknown", + ) -> Any: + """ + Create a context manager to track memory changes. 
+ + Args: + memory: LangChain memory instance + operation: Description of the operation + + Returns: + Context manager + """ + return _MemoryTrackingContext( + memory=memory, + operation=operation, + tracker=self, + stratix=self._stratix, + ) + + def record_mutation(self, mutation: dict[str, Any]) -> None: + """Record a mutation.""" + self._mutations.append(mutation) + + def get_mutations(self) -> list[dict[str, Any]]: + """Get all recorded mutations.""" + return self._mutations + + def clear(self) -> None: + """Clear recorded mutations.""" + self._mutations.clear() + + +class _MemoryTrackingContext: + """Context manager for tracking memory changes.""" + + def __init__( + self, + memory: Any, + operation: str, + tracker: MemoryMutationTracker, + stratix: Any, + ) -> None: + self._memory = memory + self._operation = operation + self._tracker = tracker + self._stratix = stratix + self._adapter = LangChainMemoryAdapter(memory) + self._before_snapshot = None + + def __enter__(self) -> _MemoryTrackingContext: + self._before_snapshot = self._adapter.snapshot() # type: ignore[assignment] + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + after_snapshot = self._adapter.snapshot() + + if self._adapter.has_changed(self._before_snapshot, after_snapshot): # type: ignore[arg-type] + diff = self._adapter.diff(self._before_snapshot, after_snapshot) # type: ignore[arg-type] + + mutation = { + "operation": self._operation, + "before_hash": self._before_snapshot.hash, # type: ignore[attr-defined] + "after_hash": after_snapshot.hash, + "diff": diff, + "timestamp_ns": time.time_ns(), + } + + self._tracker.record_mutation(mutation) + + # Emit event + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit( + "agent.state.change", + { + "memory_type": self._before_snapshot.memory_type, # type: ignore[attr-defined] + "before_hash": self._before_snapshot.hash, # type: ignore[attr-defined] + "after_hash": after_snapshot.hash, + 
"operation": self._operation, + }, + ) diff --git a/src/layerlens/instrument/adapters/frameworks/langchain/state.py b/src/layerlens/instrument/adapters/frameworks/langchain/state.py new file mode 100644 index 00000000..897b64a8 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langchain/state.py @@ -0,0 +1,205 @@ +""" +STRATIX LangChain Memory State Adapter + +Adapts LangChain memory for STRATIX state tracking. +""" + +from __future__ import annotations + +import json +import hashlib +from typing import Any +from dataclasses import dataclass + + +@dataclass +class MemorySnapshot: + """Snapshot of memory state at a point in time.""" + + memory_type: str + variables: dict[str, Any] + hash: str + timestamp_ns: int + message_count: int | None = None + + +class LangChainMemoryAdapter: + """ + State adapter for LangChain memory classes. + + Supports various LangChain memory types: + - ConversationBufferMemory + - ConversationSummaryMemory + - ConversationBufferWindowMemory + - Entity memory, etc. + + Usage: + from langchain.memory import ConversationBufferMemory # type: ignore[import-untyped,unused-ignore] + + memory = ConversationBufferMemory() + adapter = LangChainMemoryAdapter(memory) + + # Take snapshot + before = adapter.snapshot() + + # ... use memory ... + + # Check for changes + after = adapter.snapshot() + if adapter.has_changed(before, after): + print("Memory changed!") + """ + + def __init__(self, memory: Any) -> None: + """ + Initialize the memory adapter. + + Args: + memory: LangChain memory instance + """ + self._memory = memory + self._memory_type = type(memory).__name__ + + def snapshot(self) -> MemorySnapshot: + """ + Create a snapshot of the current memory state. 
+ + Returns: + MemorySnapshot with hash for comparison + """ + import time + + # Get memory variables + variables = self._get_memory_variables() + + # Count messages if applicable + message_count = self._count_messages() + + # Compute hash + hash_value = self._compute_hash(variables) + + return MemorySnapshot( + memory_type=self._memory_type, + variables=variables, + hash=hash_value, + timestamp_ns=time.time_ns(), + message_count=message_count, + ) + + def has_changed(self, before: MemorySnapshot, after: MemorySnapshot) -> bool: + """ + Check if memory has changed between snapshots. + + Args: + before: Snapshot before operation + after: Snapshot after operation + + Returns: + True if memory changed + """ + return before.hash != after.hash + + def diff(self, before: MemorySnapshot, after: MemorySnapshot) -> dict[str, Any]: + """ + Compute the difference between two snapshots. + + Args: + before: Snapshot before operation + after: Snapshot after operation + + Returns: + Dictionary describing changes + """ + added = {} + removed = {} + modified = {} + + before_vars = before.variables + after_vars = after.variables + + before_keys = set(before_vars.keys()) + after_keys = set(after_vars.keys()) + + # Added variables + for key in after_keys - before_keys: + added[key] = after_vars[key] + + # Removed variables + for key in before_keys - after_keys: + removed[key] = before_vars[key] + + # Modified variables + for key in before_keys & after_keys: + if before_vars[key] != after_vars[key]: + modified[key] = { + "before": before_vars[key], + "after": after_vars[key], + } + + # Message diff if applicable + messages_added = None + if before.message_count is not None and after.message_count is not None: # noqa: SIM102 + if after.message_count > before.message_count: + messages_added = after.message_count - before.message_count + + return { + "added": added, + "removed": removed, + "modified": modified, + "messages_added": messages_added, + } + + def get_hash(self) -> str: + """ + 
Get current memory hash without creating full snapshot. + + Returns: + Hash string + """ + variables = self._get_memory_variables() + return self._compute_hash(variables) + + def _get_memory_variables(self) -> dict[str, Any]: + """Get memory variables dictionary.""" + # Use load_memory_variables if available + if hasattr(self._memory, "load_memory_variables"): + try: + return dict(self._memory.load_memory_variables({})) + except Exception: + pass + + # Fallback to chat_memory.messages + if hasattr(self._memory, "chat_memory"): + messages = getattr(self._memory.chat_memory, "messages", []) + return {"messages": [self._serialize_message(m) for m in messages]} + + # Fallback to buffer attribute + if hasattr(self._memory, "buffer"): + return {"buffer": self._memory.buffer} + + return {} + + def _count_messages(self) -> int | None: + """Count messages in memory.""" + if hasattr(self._memory, "chat_memory"): + messages = getattr(self._memory.chat_memory, "messages", []) + return len(messages) + return None + + def _serialize_message(self, message: Any) -> dict[str, Any]: + """Serialize a message for hashing.""" + if hasattr(message, "content") and hasattr(message, "type"): + return { + "type": message.type, + "content": message.content, + } + return {"content": str(message)} + + def _compute_hash(self, variables: dict[str, Any]) -> str: + """Compute SHA-256 hash of memory state.""" + try: + serialized = json.dumps(variables, sort_keys=True, default=str) + except TypeError: + serialized = str(variables) + + return hashlib.sha256(serialized.encode()).hexdigest() diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/__init__.py b/src/layerlens/instrument/adapters/frameworks/langfuse/__init__.py new file mode 100644 index 00000000..fb28a0fb --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/__init__.py @@ -0,0 +1,46 @@ +""" +STRATIX Langfuse Adapter + +Bidirectional trace sync between STRATIX and Langfuse. 
+ +Unlike other adapters that wrap running code in real-time, the Langfuse +adapter is a data import/export pipeline that communicates with a remote +Langfuse HTTP API to pull/push traces in batch. + +Usage: + from layerlens.instrument.adapters.frameworks.langfuse import LangfuseAdapter + from layerlens.instrument.adapters.frameworks.langfuse.config import LangfuseConfig + + config = LangfuseConfig( + public_key="pk-...", + secret_key="sk-...", + ) + + adapter = LangfuseAdapter(stratix=stratix_instance, config=config) + adapter.connect() + + # Import traces from Langfuse + result = adapter.import_traces(since=datetime(2024, 1, 1)) + + # Export STRATIX traces to Langfuse + result = adapter.export_traces(events_by_trace={"trace-1": [...]}) +""" + +from __future__ import annotations + +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat, requires_pydantic + +# Round-2 deliberation item 20: ``frameworks/langfuse/config.py`` uses +# ``field_validator`` (v2-only); fail fast under v1 with a clear message +# instead of a confusing ImportError from config.py. +requires_pydantic(PydanticCompat.V2_ONLY) + +from layerlens.instrument.adapters.frameworks.langfuse.lifecycle import LangfuseAdapter + +# Registry lazy-loading convention +ADAPTER_CLASS = LangfuseAdapter + +__all__ = [ + "LangfuseAdapter", + "ADAPTER_CLASS", +] diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/client.py b/src/layerlens/instrument/adapters/frameworks/langfuse/client.py new file mode 100644 index 00000000..def3e8ea --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/client.py @@ -0,0 +1,293 @@ +""" +Langfuse API Client + +HTTP client for the Langfuse REST API using stdlib urllib. +Supports Basic auth, pagination, and exponential backoff. 
+""" + +from __future__ import annotations + +import json +import time +import base64 +import logging +import contextlib +from typing import Any +from datetime import datetime, timezone + +UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. +from urllib.error import URLError, HTTPError +from urllib.request import Request, urlopen + +logger = logging.getLogger(__name__) + +# Langfuse API rate limit: 429 responses trigger backoff +_DEFAULT_MAX_RETRIES = 3 +_BACKOFF_BASE_S = 1.0 +_BACKOFF_MAX_S = 16.0 +_REQUEST_TIMEOUT_S = 30 + + +class LangfuseAPIError(Exception): + """Raised when a Langfuse API call fails.""" + + def __init__(self, message: str, status_code: int | None = None, body: str = "") -> None: + super().__init__(message) + self.status_code = status_code + self.body = body + + +class LangfuseAPIClient: + """ + HTTP client for the Langfuse REST API. + + Uses Basic auth with base64(public_key:secret_key). + No external dependencies — built on stdlib urllib.request. + """ + + def __init__( + self, + public_key: str, + secret_key: str, + host: str = "https://cloud.langfuse.com", + max_retries: int = _DEFAULT_MAX_RETRIES, + timeout: int = _REQUEST_TIMEOUT_S, + ) -> None: + self._host = host.rstrip("/") + self._max_retries = max_retries + self._timeout = timeout + + # Basic auth header + credentials = f"{public_key}:{secret_key}" + encoded = base64.b64encode(credentials.encode()).decode() + self._auth_header = f"Basic {encoded}" + + # --- Public API --- + + def health_check(self) -> dict[str, Any]: + """Check Langfuse API health.""" + return self._request("GET", "/api/public/health") + + def list_traces( + self, + page: int = 1, + limit: int = 50, + order_by: str = "timestamp", + order: str = "DESC", + name: str | None = None, + tags: list[str] | None = None, + from_timestamp: datetime | None = None, + to_timestamp: datetime | None = None, + ) -> dict[str, Any]: + """ + List traces with pagination and filtering. 
+ + Returns dict with 'data' (list of trace objects) and 'meta' (pagination info). + """ + params: dict[str, Any] = { + "page": page, + "limit": limit, + "orderBy": order_by, + "order": order, + } + if name: + params["name"] = name + if tags: + for tag in tags: + params.setdefault("tags", []).append(tag) + if from_timestamp: + params["fromTimestamp"] = from_timestamp.isoformat() + if to_timestamp: + params["toTimestamp"] = to_timestamp.isoformat() + + return self._request("GET", "/api/public/traces", params=params) + + def get_trace(self, trace_id: str) -> dict[str, Any]: + """Get a single trace with all observations.""" + return self._request("GET", f"/api/public/traces/{trace_id}") + + def list_observations( + self, + trace_id: str | None = None, + page: int = 1, + limit: int = 50, + type: str | None = None, + ) -> dict[str, Any]: + """List observations for a trace.""" + params: dict[str, Any] = {"page": page, "limit": limit} + if trace_id: + params["traceId"] = trace_id + if type: + params["type"] = type + return self._request("GET", "/api/public/observations", params=params) + + def create_trace(self, trace_data: dict[str, Any]) -> dict[str, Any]: + """Create a new trace in Langfuse.""" + return self._request( + "POST", + "/api/public/ingestion", + body={ + "batch": [ + { + "id": trace_data.get("id", ""), + "type": "trace-create", + "timestamp": datetime.now(UTC).isoformat(), + "body": trace_data, + } + ], + }, + ) + + def create_generation(self, generation_data: dict[str, Any]) -> dict[str, Any]: + """Create a generation observation.""" + return self._request( + "POST", + "/api/public/ingestion", + body={ + "batch": [ + { + "id": generation_data.get("id", ""), + "type": "generation-create", + "timestamp": datetime.now(UTC).isoformat(), + "body": generation_data, + } + ], + }, + ) + + def create_span(self, span_data: dict[str, Any]) -> dict[str, Any]: + """Create a span observation.""" + return self._request( + "POST", + "/api/public/ingestion", + body={ + 
"batch": [ + { + "id": span_data.get("id", ""), + "type": "span-create", + "timestamp": datetime.now(UTC).isoformat(), + "body": span_data, + } + ], + }, + ) + + def ingestion_batch(self, events: list[dict[str, Any]]) -> dict[str, Any]: + """Send a batch of ingestion events.""" + return self._request("POST", "/api/public/ingestion", body={"batch": events}) + + def get_all_traces( + self, + limit: int = 50, + tags: list[str] | None = None, + from_timestamp: datetime | None = None, + to_timestamp: datetime | None = None, + ) -> list[dict[str, Any]]: + """ + Fetch all traces with automatic pagination. + + Requests pages until exhausted and returns them as a single list. + """ + all_traces: list[dict[str, Any]] = [] + page = 1 + while True: + result = self.list_traces( + page=page, + limit=limit, + tags=tags, + from_timestamp=from_timestamp, + to_timestamp=to_timestamp, + ) + data = result.get("data", []) + if not data: + break + all_traces.extend(data) + meta = result.get("meta", {}) + total_pages = meta.get("totalPages", 1) + if page >= total_pages: + break + page += 1 + return all_traces + + # --- Internal --- + + def _request( + self, + method: str, + path: str, + params: dict[str, Any] | None = None, + body: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Make an HTTP request with retry and backoff.""" + url = f"{self._host}{path}" + if params: + # URL-encode the query string. doseq=True expands list params (e.g., tags) + # into repeated keys, and encoding preserves "+" in ISO-8601 UTC offsets. + from urllib.parse import urlencode + + url = f"{url}?{urlencode(params, doseq=True)}" + + headers = { + "Authorization": self._auth_header, + "Content-Type": "application/json", + "Accept": "application/json", + } + + data = json.dumps(body).encode() if body else None + + last_error: Exception | None = None + for attempt in range(self._max_retries + 1): + try: + req = Request(url, data=data, headers=headers, method=method) + with urlopen(req, timeout=self._timeout) as resp: + resp_body = 
resp.read().decode() + if not resp_body: + return {} + return json.loads(resp_body) # type: ignore[no-any-return] + + except HTTPError as e: + status = e.code + error_body = "" + with contextlib.suppress(Exception): + error_body = e.read().decode() + + if status == 429 or status >= 500: + last_error = LangfuseAPIError( + f"HTTP {status}: {error_body}", status_code=status, body=error_body + ) + if attempt < self._max_retries: + delay = min(_BACKOFF_BASE_S * (2**attempt), _BACKOFF_MAX_S) + logger.debug( + "Langfuse API %s %s returned %d, retrying in %.1fs (attempt %d/%d)", + method, + path, + status, + delay, + attempt + 1, + self._max_retries, + ) + time.sleep(delay) + continue + raise LangfuseAPIError( # noqa: B904 + f"HTTP {status}: {error_body}", status_code=status, body=error_body + ) + + except URLError as e: + last_error = LangfuseAPIError(f"Connection error: {e}") + if attempt < self._max_retries: + delay = min(_BACKOFF_BASE_S * (2**attempt), _BACKOFF_MAX_S) + logger.debug( + "Langfuse API connection error, retrying in %.1fs (attempt %d/%d)", + delay, + attempt + 1, + self._max_retries, + ) + time.sleep(delay) + continue + raise last_error # noqa: B904 + + raise last_error or LangfuseAPIError("Max retries exceeded") diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/config.py b/src/layerlens/instrument/adapters/frameworks/langfuse/config.py new file mode 100644 index 00000000..afa6e835 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/config.py @@ -0,0 +1,143 @@ +""" +Langfuse Adapter Configuration Models + +Pydantic models for Langfuse adapter configuration, sync state tracking, +and sync result reporting. +""" + +from __future__ import annotations + +from enum import Enum # Python 3.11+ has StrEnum; using `(str, Enum)` for 3.9/3.10 compat. 
+from typing import Optional +from datetime import datetime + +from pydantic import Field, BaseModel, field_validator + + +class SyncDirection(str, Enum): + """Direction of synchronization.""" + + IMPORT = "import" + EXPORT = "export" + BIDIRECTIONAL = "bidirectional" + + +class ConflictStrategy(str, Enum): + """Strategy for resolving sync conflicts.""" + + LAST_WRITE_WINS = "last-write-wins" + MANUAL = "manual" + + +class LangfuseConfig(BaseModel): + """Configuration for the Langfuse adapter.""" + + public_key: str = Field(description="Langfuse public API key") + secret_key: str = Field(description="Langfuse secret API key") + host: str = Field( + default="https://cloud.langfuse.com", + description="Langfuse API host URL", + ) + mode: SyncDirection = Field( + default=SyncDirection.IMPORT, + description="Sync mode: import, export, or bidirectional", + ) + sync_interval_seconds: int = Field( + default=3600, + description="Auto-sync interval in seconds (0 = disabled)", + ) + project_filter: Optional[str] = Field( + default=None, + description="Filter by Langfuse project name", + ) + tag_filter: Optional[list[str]] = Field( + default=None, + description="Filter by trace tags", + ) + since: Optional[datetime] = Field( + default=None, + description="Only sync traces after this timestamp", + ) + conflict_strategy: ConflictStrategy = Field( + default=ConflictStrategy.LAST_WRITE_WINS, + description="Conflict resolution strategy", + ) + max_retries: int = Field(default=3, description="Max retries per API call") + page_size: int = Field(default=50, description="Page size for listing traces") + + @field_validator("host") + @classmethod + def strip_trailing_slash(cls, v: str) -> str: + return v.rstrip("/") + + +class SyncState(BaseModel): + """Tracks the state of a Langfuse sync session.""" + + last_import_cursor: Optional[datetime] = Field( + default=None, + description="Timestamp of the last imported trace", + ) + last_export_cursor: Optional[datetime] = Field( + 
default=None, + description="Timestamp of the last exported trace", + ) + imported_trace_ids: set[str] = Field( + default_factory=set, + description="Set of Langfuse trace IDs that have been imported", + ) + exported_trace_ids: set[str] = Field( + default_factory=set, + description="Set of STRATIX trace IDs that have been exported", + ) + quarantined_trace_ids: dict[str, int] = Field( + default_factory=dict, + description="Trace IDs that have failed repeatedly, mapped to failure count", + ) + + def record_import(self, trace_id: str, updated_at: datetime) -> None: + """Record a successful import.""" + self.imported_trace_ids.add(trace_id) + if self.last_import_cursor is None or updated_at > self.last_import_cursor: + self.last_import_cursor = updated_at + # Clear from quarantine on success + self.quarantined_trace_ids.pop(trace_id, None) + + def record_export(self, trace_id: str, updated_at: datetime) -> None: + """Record a successful export.""" + self.exported_trace_ids.add(trace_id) + if self.last_export_cursor is None or updated_at > self.last_export_cursor: + self.last_export_cursor = updated_at + + def record_failure(self, trace_id: str, max_failures: int = 3) -> bool: + """ + Record a failure for a trace. Returns True if the trace is now quarantined. 
+ """ + count = self.quarantined_trace_ids.get(trace_id, 0) + 1 + self.quarantined_trace_ids[trace_id] = count + return count >= max_failures + + def is_quarantined(self, trace_id: str) -> bool: + """Check if a trace is quarantined (3+ failures).""" + return self.quarantined_trace_ids.get(trace_id, 0) >= 3 + + def clear_quarantine(self, trace_id: str | None = None) -> None: + """Clear quarantine for a specific trace or all traces.""" + if trace_id: + self.quarantined_trace_ids.pop(trace_id, None) + else: + self.quarantined_trace_ids.clear() + + +class SyncResult(BaseModel): + """Result of a sync operation.""" + + direction: SyncDirection + imported_count: int = Field(default=0) + exported_count: int = Field(default=0) + skipped_count: int = Field(default=0) + failed_count: int = Field(default=0) + quarantined_count: int = Field(default=0) + errors: list[str] = Field(default_factory=list) + duration_ms: float = Field(default=0.0) + dry_run: bool = Field(default=False) diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/exporter.py b/src/layerlens/instrument/adapters/frameworks/langfuse/exporter.py new file mode 100644 index 00000000..d8df541f --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/exporter.py @@ -0,0 +1,141 @@ +""" +Langfuse Trace Exporter + +Reverse-maps STRATIX events to Langfuse traces and pushes them via the API. +""" + +from __future__ import annotations + +import uuid +import logging +from typing import Any +from datetime import datetime, timezone + +UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. 
+ +from layerlens.instrument.adapters.frameworks.langfuse.client import LangfuseAPIError, LangfuseAPIClient +from layerlens.instrument.adapters.frameworks.langfuse.config import SyncState, SyncResult, SyncDirection +from layerlens.instrument.adapters.frameworks.langfuse.mapper import LayerLensToLangfuseMapper + +logger = logging.getLogger(__name__) + + +class TraceExporter: + """ + Export pipeline for STRATIX -> Langfuse. + + Steps: + 1. Group STRATIX events by trace ID + 2. Reverse-map to Langfuse trace + observations + 3. Create trace and observations via Langfuse API + 4. Tag with 'stratix-exported' to prevent re-import + """ + + def __init__( + self, + client: LangfuseAPIClient, + state: SyncState, + ) -> None: + self._client = client + self._state = state + self._mapper = LayerLensToLangfuseMapper() + + def export_traces( + self, + events_by_trace: dict[str, list[dict[str, Any]]], + trace_ids: list[str] | None = None, + dry_run: bool = False, + ) -> SyncResult: + """ + Export STRATIX traces to Langfuse. + + Args: + events_by_trace: Dict mapping trace_id -> list of STRATIX event dicts. + trace_ids: Optional filter — only export these trace IDs. + dry_run: If True, count but don't actually export. + + Returns: + SyncResult with export statistics. 
+ """ + result = SyncResult(direction=SyncDirection.EXPORT, dry_run=dry_run) + + ids_to_export = trace_ids or list(events_by_trace.keys()) + + for trace_id in ids_to_export: + events = events_by_trace.get(trace_id, []) + if not events: + result.skipped_count += 1 + continue + + # Loop prevention: skip traces that were imported from Langfuse + if trace_id in self._state.imported_trace_ids: + result.skipped_count += 1 + continue + + # Skip already exported + if trace_id in self._state.exported_trace_ids: + result.skipped_count += 1 + continue + + if dry_run: + result.exported_count += 1 + continue + + # Map STRATIX events to Langfuse structure + try: + langfuse_data = self._mapper.map_events_to_trace(events, trace_id=trace_id) + except Exception as e: + logger.warning("Failed to map trace %s for export: %s", trace_id, e) + result.failed_count += 1 + result.errors.append(f"Trace {trace_id} mapping: {e}") + continue + + # Push to Langfuse + try: + self._push_to_langfuse(langfuse_data) + except LangfuseAPIError as e: + logger.warning("Failed to export trace %s: %s", trace_id, e) + result.failed_count += 1 + result.errors.append(f"Trace {trace_id} export: {e}") + continue + + # Record success + self._state.record_export(trace_id, datetime.now(UTC)) + result.exported_count += 1 + + return result + + def _push_to_langfuse(self, langfuse_data: dict[str, Any]) -> None: + """Push a mapped trace + observations to Langfuse via batch ingestion.""" + trace_body = langfuse_data.get("trace", {}) + observations = langfuse_data.get("observations", []) + + # Build batch events + batch: list[dict[str, Any]] = [] + now = datetime.now(UTC).isoformat() + + # Trace create event + batch.append( + { + "id": str(uuid.uuid4()), + "type": "trace-create", + "timestamp": now, + "body": trace_body, + } + ) + + # Observation create events + for obs in observations: + obs_type = obs.get("type", "SPAN").upper() + event_type = "generation-create" if obs_type == "GENERATION" else "span-create" + + 
batch.append( + { + "id": str(uuid.uuid4()), + "type": event_type, + "timestamp": now, + "body": obs, + } + ) + + self._client.ingestion_batch(batch) diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/importer.py b/src/layerlens/instrument/adapters/frameworks/langfuse/importer.py new file mode 100644 index 00000000..e309fda3 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/importer.py @@ -0,0 +1,177 @@ +""" +Langfuse Trace Importer + +Fetches traces from Langfuse, maps them to STRATIX events, deduplicates, +and ingests via the STRATIX pipeline. +""" + +from __future__ import annotations + +import logging +from typing import Any +from datetime import datetime, timezone + +UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. + +from layerlens.instrument.adapters.frameworks.langfuse.client import LangfuseAPIError, LangfuseAPIClient +from layerlens.instrument.adapters.frameworks.langfuse.config import SyncState, SyncResult, SyncDirection +from layerlens.instrument.adapters.frameworks.langfuse.mapper import LangfuseToLayerLensMapper + +logger = logging.getLogger(__name__) + + +class TraceImporter: + """ + Import pipeline for Langfuse -> STRATIX. + + Steps: + 1. List traces from Langfuse (with filters) + 2. Fetch full trace with observations + 3. Map to STRATIX events + 4. Deduplicate against previously imported traces + 5. Ingest via STRATIX emit or pipeline + """ + + def __init__( + self, + client: LangfuseAPIClient, + state: SyncState, + ) -> None: + self._client = client + self._state = state + self._mapper = LangfuseToLayerLensMapper() + + def import_traces( + self, + stratix: Any | None = None, + since: datetime | None = None, + tags: list[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + ) -> SyncResult: + """ + Import traces from Langfuse. + + Args: + stratix: STRATIX instance for event emission (or pipeline). + since: Only import traces after this timestamp. 
+ tags: Filter by trace tags. + limit: Max number of traces to import. + dry_run: If True, count but don't actually import. + + Returns: + SyncResult with import statistics. + """ + result = SyncResult(direction=SyncDirection.IMPORT, dry_run=dry_run) + + # Fetch trace list + try: + traces = self._client.get_all_traces( + tags=tags, + from_timestamp=since, + ) + except LangfuseAPIError as e: + result.errors.append(f"Failed to list traces: {e}") + result.failed_count = 1 + return result + + if limit: + traces = traces[:limit] + + for trace_summary in traces: + trace_id = trace_summary.get("id", "") + + # Skip quarantined traces + if self._state.is_quarantined(trace_id): + result.quarantined_count += 1 + continue + + # Dedup: skip traces that were already imported (updatedAt is not re-checked here) + if trace_id in self._state.imported_trace_ids: + result.skipped_count += 1 + continue + + # Skip traces exported by STRATIX (loop prevention) + trace_tags = trace_summary.get("tags", []) or [] + if "stratix-exported" in trace_tags: + result.skipped_count += 1 + continue + + if dry_run: + result.imported_count += 1 + continue + + # Fetch full trace + try: + full_trace = self._client.get_trace(trace_id) + except LangfuseAPIError as e: + logger.warning("Failed to fetch trace %s: %s", trace_id, e) + is_quarantined = self._state.record_failure(trace_id) + if is_quarantined: + result.quarantined_count += 1 + result.failed_count += 1 + result.errors.append(f"Trace {trace_id}: {e}") + continue + + # Map to STRATIX events + try: + events = self._mapper.map_trace(full_trace) + except Exception as e: + logger.warning("Failed to map trace %s: %s", trace_id, e) + is_quarantined = self._state.record_failure(trace_id) + if is_quarantined: + result.quarantined_count += 1 + result.failed_count += 1 + result.errors.append(f"Trace {trace_id} mapping: {e}") + continue + + if not events: + result.skipped_count += 1 + continue + + # Ingest events + try: + self._ingest_events(events, stratix) + except Exception as e: 
+ logger.warning("Failed to ingest trace %s: %s", trace_id, e) + is_quarantined = self._state.record_failure(trace_id) + if is_quarantined: + result.quarantined_count += 1 + result.failed_count += 1 + result.errors.append(f"Trace {trace_id} ingestion: {e}") + continue + + # Record success + updated_at = self._parse_timestamp( + full_trace.get("updatedAt", full_trace.get("timestamp")) + ) + self._state.record_import(trace_id, updated_at) + result.imported_count += 1 + + return result + + def _ingest_events( + self, + events: list[dict[str, Any]], + stratix: Any | None, + ) -> None: + """Ingest mapped events via STRATIX emit or pipeline.""" + if stratix is None or not bool(stratix): + return + + for event in events: + event_type = event.get("event_type", "") + payload = event.get("payload", {}) + stratix.emit(event_type, payload) + + @staticmethod + def _parse_timestamp(value: Any) -> datetime: + """Parse a timestamp string to datetime, or return now.""" + if isinstance(value, datetime): + return value + if isinstance(value, str): + try: + return datetime.fromisoformat(value.replace("Z", "+00:00")) + except (ValueError, TypeError): + pass + return datetime.now(UTC) diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/langfuse/lifecycle.py new file mode 100644 index 00000000..e8d6c9b6 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/lifecycle.py @@ -0,0 +1,338 @@ +""" +Langfuse Adapter Lifecycle + +Main LangfuseAdapter class extending BaseAdapter. +Manages connection, health, import/export, and sync operations. +""" + +from __future__ import annotations + +import time +import uuid +import logging +from typing import Any +from datetime import datetime, timezone + +UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. 
+ +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.langfuse.sync import BidirectionalSync +from layerlens.instrument.adapters.frameworks.langfuse.client import LangfuseAPIError, LangfuseAPIClient +from layerlens.instrument.adapters.frameworks.langfuse.config import ( + SyncState, + SyncResult, + SyncDirection, + LangfuseConfig, +) +from layerlens.instrument.adapters.frameworks.langfuse.exporter import TraceExporter +from layerlens.instrument.adapters.frameworks.langfuse.importer import TraceImporter + +logger = logging.getLogger(__name__) + + +class LangfuseAdapter(BaseAdapter): + """ + LayerLens adapter for Langfuse integration. + + Unlike other adapters that wrap running code in real-time, the Langfuse + adapter is a data import/export pipeline that communicates with a remote + Langfuse HTTP API to pull/push traces in batch. + """ + + FRAMEWORK = "langfuse" + VERSION = "0.1.0" + # The adapter's own config layer + # (``frameworks/langfuse/config.py`` line 13) imports + # ``from pydantic import field_validator`` — a v2-only decorator. + # Pydantic v1 has ``validator``; ``field_validator`` was added in v2 + # (see pydantic v2 migration guide). Importing this adapter under v1 + # raises ``ImportError`` in config.py. 
+ requires_pydantic = PydanticCompat.V2_ONLY + + def __init__( + self, + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + config: LangfuseConfig | None = None, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + self._config: LangfuseConfig | None = config + self._client: LangfuseAPIClient | None = None + self._sync_state = SyncState() + self._importer: TraceImporter | None = None + self._exporter: TraceExporter | None = None + self._sync: BidirectionalSync | None = None + self._last_health_check: datetime | None = None + self._langfuse_healthy = False + + # --- BaseAdapter abstract methods --- + + def connect(self, config: LangfuseConfig | None = None) -> None: + """ + Connect to the Langfuse API. + + Creates the HTTP client and validates credentials with a health check. + """ + if config: + self._config = config + + if self._config is None: + # Connect without a config — adapter is usable but not connected to Langfuse + self._connected = True + self._status = AdapterStatus.HEALTHY + return + + self._client = LangfuseAPIClient( + public_key=self._config.public_key, + secret_key=self._config.secret_key, + host=self._config.host, + max_retries=self._config.max_retries, + ) + + # Validate credentials + try: + self._client.health_check() + self._langfuse_healthy = True + except LangfuseAPIError as e: + logger.warning("Langfuse health check failed: %s", e) + self._langfuse_healthy = False + + # Initialize sub-components + self._importer = TraceImporter(self._client, self._sync_state) + self._exporter = TraceExporter(self._client, self._sync_state) + self._sync = BidirectionalSync( + importer=self._importer, + exporter=self._exporter, + state=self._sync_state, + ) + + self._connected = True + self._status = AdapterStatus.HEALTHY if self._langfuse_healthy else AdapterStatus.DEGRADED + self._last_health_check = datetime.now(UTC) + + def disconnect(self) -> None: + """Disconnect from Langfuse.""" + self._client = 
None + self._importer = None + self._exporter = None + self._sync = None + self._connected = False + self._status = AdapterStatus.DISCONNECTED + self._langfuse_healthy = False + + def health_check(self) -> AdapterHealth: + """Return health status including Langfuse API reachability.""" + message = None + if self._client and self._connected: + try: + self._client.health_check() + self._langfuse_healthy = True + message = "Langfuse API reachable" + except LangfuseAPIError as e: + self._langfuse_healthy = False + message = f"Langfuse API unreachable: {e}" + self._status = AdapterStatus.DEGRADED + elif not self._config: + message = "No Langfuse config — adapter connected without remote API" + else: + message = "Not connected" + + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + framework_version=None, + adapter_version=self.VERSION, + message=message, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + """Return metadata about this adapter.""" + return AdapterInfo( + name="LangfuseAdapter", + version=self.VERSION, + framework=self.FRAMEWORK, + framework_version=None, + capabilities=[ + AdapterCapability.TRACE_TOOLS, + AdapterCapability.TRACE_MODELS, + AdapterCapability.REPLAY, + ], + author="STRATIX Team", + description="Bidirectional trace sync between STRATIX and Langfuse", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + """Serialize accumulated trace events for replay.""" + return ReplayableTrace( + adapter_name="LangfuseAdapter", + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + config=self._config.model_dump() if self._config else {}, + metadata={ + "sync_state": { + "imported": len(self._sync_state.imported_trace_ids), + "exported": len(self._sync_state.exported_trace_ids), + "quarantined": len(self._sync_state.quarantined_trace_ids), + }, + }, + ) + + # --- Import/Export/Sync API --- + + def import_traces( 
+ self, + since: datetime | None = None, + tags: list[str] | None = None, + limit: int | None = None, + dry_run: bool = False, + ) -> SyncResult: + """ + Import traces from Langfuse into STRATIX. + + Args: + since: Only import traces updated after this timestamp. + tags: Filter by Langfuse trace tags. + limit: Maximum number of traces to import. + dry_run: If True, report what would be imported without importing. + + Returns: + SyncResult with import statistics. + """ + if self._importer is None: + return SyncResult( + direction=SyncDirection.IMPORT, + errors=["Adapter not connected to Langfuse API"], + ) + + start_time = time.monotonic() + effective_since = since or (self._config.since if self._config else None) + effective_tags = tags or (self._config.tag_filter if self._config else None) + + result = self._importer.import_traces( + stratix=self._stratix, + since=effective_since, + tags=effective_tags, + limit=limit, + dry_run=dry_run, + ) + result.duration_ms = (time.monotonic() - start_time) * 1000 + return result + + def export_traces( + self, + events_by_trace: dict[str, list[dict[str, Any]]] | None = None, + trace_ids: list[str] | None = None, + dry_run: bool = False, + ) -> SyncResult: + """ + Export STRATIX traces to Langfuse. + + Args: + events_by_trace: Dict mapping trace_id -> list of STRATIX event dicts. + trace_ids: List of trace IDs to export (requires events_by_trace). + dry_run: If True, report what would be exported without exporting. + + Returns: + SyncResult with export statistics. 
+ """ + if self._exporter is None: + return SyncResult( + direction=SyncDirection.EXPORT, + errors=["Adapter not connected to Langfuse API"], + ) + + start_time = time.monotonic() + result = self._exporter.export_traces( + events_by_trace=events_by_trace or {}, + trace_ids=trace_ids, + dry_run=dry_run, + ) + result.duration_ms = (time.monotonic() - start_time) * 1000 + return result + + def sync( + self, + direction: SyncDirection | None = None, + since: datetime | None = None, + dry_run: bool = False, + events_by_trace: dict[str, list[dict[str, Any]]] | None = None, + ) -> SyncResult: + """ + Run a sync cycle in the configured direction. + + Args: + direction: Override the configured sync direction. + since: Override the since timestamp. + dry_run: If True, report what would be synced without making changes. + events_by_trace: Required for export/bidirectional — STRATIX events to export. + + Returns: + SyncResult with combined statistics. + """ + if self._sync is None: + return SyncResult( + direction=direction or SyncDirection.IMPORT, + errors=["Adapter not connected to Langfuse API"], + ) + + effective_direction = direction or ( + self._config.mode if self._config else SyncDirection.IMPORT + ) + start_time = time.monotonic() + + result = self._sync.run( + stratix=self._stratix, + direction=effective_direction, + since=since, + dry_run=dry_run, + events_by_trace=events_by_trace or {}, + tags=self._config.tag_filter if self._config else None, + ) + result.duration_ms = (time.monotonic() - start_time) * 1000 + return result + + # --- State access --- + + @property + def sync_state(self) -> SyncState: + """Return the current sync state.""" + return self._sync_state + + @property + def config(self) -> LangfuseConfig | None: + """Return the current configuration.""" + return self._config + + def get_status(self) -> dict[str, Any]: + """Return a status summary for CLI/API use.""" + return { + "connected": self._connected, + "langfuse_healthy": self._langfuse_healthy, + 
"host": self._config.host if self._config else None, + "mode": self._config.mode.value if self._config else None, + "imported_traces": len(self._sync_state.imported_trace_ids), + "exported_traces": len(self._sync_state.exported_trace_ids), + "quarantined_traces": len(self._sync_state.quarantined_trace_ids), + "last_import_cursor": ( + self._sync_state.last_import_cursor.isoformat() + if self._sync_state.last_import_cursor + else None + ), + "last_export_cursor": ( + self._sync_state.last_export_cursor.isoformat() + if self._sync_state.last_export_cursor + else None + ), + } diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/mapper.py b/src/layerlens/instrument/adapters/frameworks/langfuse/mapper.py new file mode 100644 index 00000000..2dd0eb75 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/mapper.py @@ -0,0 +1,624 @@ +""" +Langfuse <-> STRATIX Bidirectional Field Mapper + +Maps Langfuse trace/observation structures to STRATIX canonical events +and vice versa. +""" + +from __future__ import annotations + +import uuid +import logging +from typing import Any +from datetime import datetime, timezone + +UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. + +logger = logging.getLogger(__name__) + + +class LangfuseToLayerLensMapper: + """ + Maps Langfuse traces and observations to STRATIX canonical event dicts. + + Each Langfuse trace produces multiple STRATIX events: + - trace.input -> agent.input (L1) + - trace.output -> agent.output (L1) + - span -> agent.code (L2) or tool.call (L5a) + - generation -> model.invoke (L3) + cost.record (Cross) + - metadata -> environment.config (L4a) + - errors -> policy.violation (Cross) + """ + + def map_trace(self, trace: dict[str, Any]) -> list[dict[str, Any]]: + """ + Map a complete Langfuse trace (with observations) to STRATIX events. + + Args: + trace: Langfuse trace dict from the API, including nested observations. 
+ + Returns: + List of STRATIX event dicts ready for ingestion. + """ + trace_id = trace.get("id", str(uuid.uuid4())) + events: list[dict[str, Any]] = [] + timestamp = trace.get("timestamp", datetime.now(UTC).isoformat()) + seq = 0 + + # Trace-level metadata (L4a) + metadata = trace.get("metadata") + if metadata: + events.append( + self._make_event( + event_type="environment.config", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq, + payload={ + "config_type": "langfuse_trace_metadata", + "config": metadata, + "framework": "langfuse", + }, + langfuse_metadata=self._extract_trace_metadata(trace), + ) + ) + seq += 1 + + # Trace input -> agent.input (L1) + trace_input = trace.get("input") + if trace_input is not None: + events.append( + self._make_event( + event_type="agent.input", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq, + payload={ + "agent_id": trace.get("name", "langfuse_agent"), + "input_text": self._to_str(trace_input), + "input": trace_input, + "framework": "langfuse", + }, + langfuse_metadata=self._extract_trace_metadata(trace), + ) + ) + seq += 1 + + # Sort observations by start_time for temporal ordering + observations = trace.get("observations", []) + observations = sorted( + observations, + key=lambda o: o.get("startTime", o.get("start_time", "")), + ) + + for obs in observations: + obs_events = self._map_observation(obs, trace_id, seq) + events.extend(obs_events) + seq += len(obs_events) + + # Trace output -> agent.output (L1) + trace_output = trace.get("output") + if trace_output is not None: + end_time = trace.get("endTime", trace.get("end_time", timestamp)) + events.append( + self._make_event( + event_type="agent.output", + trace_id=trace_id, + timestamp=end_time or timestamp, + sequence_id=seq, + payload={ + "agent_id": trace.get("name", "langfuse_agent"), + "output_text": self._to_str(trace_output), + "output": trace_output, + "framework": "langfuse", + }, + langfuse_metadata=self._extract_trace_metadata(trace), + 
) + ) + + return events + + def _map_observation( + self, + obs: dict[str, Any], + trace_id: str, + start_seq: int, + ) -> list[dict[str, Any]]: + """Map a single Langfuse observation to STRATIX event(s).""" + obs_type = obs.get("type", "SPAN").upper() + timestamp = obs.get("startTime", obs.get("start_time", "")) + + if obs_type == "GENERATION": + return self._map_generation(obs, trace_id, timestamp, start_seq) + elif obs_type == "SPAN": + return self._map_span(obs, trace_id, timestamp, start_seq) + else: + # EVENT or unknown type — map as agent.code + return self._map_span(obs, trace_id, timestamp, start_seq) + + def _map_generation( + self, + obs: dict[str, Any], + trace_id: str, + timestamp: str, + seq: int, + ) -> list[dict[str, Any]]: + """Map a Langfuse generation to model.invoke + cost.record.""" + events: list[dict[str, Any]] = [] + + model = obs.get("model", obs.get("modelId")) + usage = obs.get("usage", obs.get("promptTokens")) + + # Compute latency + latency_ms = self._compute_latency_ms(obs) + + # Normalize token usage + if isinstance(usage, dict): + prompt_tokens = usage.get("promptTokens", usage.get("input", 0)) + completion_tokens = usage.get("completionTokens", usage.get("output", 0)) + total_tokens = usage.get("totalTokens", usage.get("total", 0)) + else: + prompt_tokens = obs.get("promptTokens", 0) + completion_tokens = obs.get("completionTokens", 0) + total_tokens = obs.get("totalTokens", 0) + + # model.invoke (L3) + payload: dict[str, Any] = { + "provider": "langfuse", + "model": model, + "tokens_prompt": prompt_tokens or 0, + "tokens_completion": completion_tokens or 0, + "tokens_total": total_tokens or (prompt_tokens or 0) + (completion_tokens or 0), + "framework": "langfuse", + } + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + # Include model parameters if present + model_params = obs.get("modelParameters") + if model_params: + payload["parameters"] = model_params + + # Check for errors + level = obs.get("level", 
"").upper() + status_message = obs.get("statusMessage", "") + if level == "ERROR": + payload["error"] = status_message or "Generation error" + + events.append( + self._make_event( + event_type="model.invoke", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq, + payload=payload, + ) + ) + + # cost.record (Cross-cutting) + total_cost = obs.get("totalCost", obs.get("calculatedTotalCost")) + if total_cost is not None and total_cost > 0: + events.append( + self._make_event( + event_type="cost.record", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq + 1, + payload={ + "model": model, + "cost_usd": total_cost, + "tokens_prompt": prompt_tokens or 0, + "tokens_completion": completion_tokens or 0, + "framework": "langfuse", + }, + ) + ) + + # Error/warning observations -> policy.violation + if level in ("ERROR", "WARNING"): + events.append( + self._make_event( + event_type="policy.violation", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq + len(events), + payload={ + "violation_type": "error" if level == "ERROR" else "warning", + "description": status_message or f"Generation {level.lower()}", + "source": "langfuse_observation", + "observation_id": obs.get("id"), + "framework": "langfuse", + }, + ) + ) + + return events + + def _map_span( + self, + obs: dict[str, Any], + trace_id: str, + timestamp: str, + seq: int, + ) -> list[dict[str, Any]]: + """Map a Langfuse span to tool.call or agent.code.""" + name = obs.get("name", "") + obs_input = obs.get("input") + obs_output = obs.get("output") + latency_ms = self._compute_latency_ms(obs) + level = obs.get("level", "").upper() + status_message = obs.get("statusMessage", "") + + # Determine if this is a tool call (metadata hint or naming convention) + metadata = obs.get("metadata", {}) or {} + is_tool = ( + metadata.get("type") == "TOOL" + or name.lower().startswith("tool_") + or name.lower().startswith("tool:") + or metadata.get("tool_name") + ) + + events: list[dict[str, Any]] = [] + + if 
is_tool: + # tool.call (L5a) + payload: dict[str, Any] = { + "tool_name": metadata.get("tool_name", name), + "framework": "langfuse", + } + if obs_input is not None: + payload["input"] = obs_input + if obs_output is not None: + payload["output"] = obs_output + if latency_ms is not None: + payload["latency_ms"] = latency_ms + if level == "ERROR": + payload["error"] = status_message or "Tool error" + + events.append( + self._make_event( + event_type="tool.call", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq, + payload=payload, + ) + ) + else: + # agent.code (L2) + payload = { + "step_name": name, + "framework": "langfuse", + } + if obs_input is not None: + payload["input"] = obs_input + if obs_output is not None: + payload["output"] = obs_output + if latency_ms is not None: + payload["latency_ms"] = latency_ms + + events.append( + self._make_event( + event_type="agent.code", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq, + payload=payload, + ) + ) + + # Error/warning -> policy.violation + if level in ("ERROR", "WARNING"): + events.append( + self._make_event( + event_type="policy.violation", + trace_id=trace_id, + timestamp=timestamp, + sequence_id=seq + 1, + payload={ + "violation_type": "error" if level == "ERROR" else "warning", + "description": status_message or f"Span {level.lower()}", + "source": "langfuse_observation", + "observation_id": obs.get("id"), + "framework": "langfuse", + }, + ) + ) + + return events + + # --- Helpers --- + + @staticmethod + def _make_event( + event_type: str, + trace_id: str, + timestamp: str, + sequence_id: int, + payload: dict[str, Any], + langfuse_metadata: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Construct a normalized STRATIX event dict.""" + event: dict[str, Any] = { + "event_type": event_type, + "trace_id": trace_id, + "timestamp": timestamp, + "sequence_id": sequence_id, + "payload": payload, + } + if langfuse_metadata: + event["metadata"] = langfuse_metadata + return event + + 
@staticmethod + def _extract_trace_metadata(trace: dict[str, Any]) -> dict[str, Any]: + """Extract Langfuse-specific metadata from a trace.""" + meta: dict[str, Any] = { + "langfuse_trace_id": trace.get("id"), + } + if trace.get("sessionId"): + meta["langfuse_session_id"] = trace["sessionId"] + if trace.get("userId"): + meta["langfuse_user_id"] = trace["userId"] + if trace.get("tags"): + meta["langfuse_tags"] = trace["tags"] + if trace.get("scores"): + meta["langfuse_scores"] = trace["scores"] + return meta + + @staticmethod + def _compute_latency_ms(obs: dict[str, Any]) -> float | None: + """Compute latency from observation start/end times.""" + start = obs.get("startTime", obs.get("start_time")) + end = obs.get("endTime", obs.get("end_time")) + if not start or not end: + return None + try: + if isinstance(start, str): + start_dt = datetime.fromisoformat(start.replace("Z", "+00:00")) + else: + start_dt = start + if isinstance(end, str): + end_dt = datetime.fromisoformat(end.replace("Z", "+00:00")) + else: + end_dt = end + delta = end_dt - start_dt + return delta.total_seconds() * 1000 + except (ValueError, TypeError): + return None + + @staticmethod + def _to_str(value: Any) -> str: + """Convert a value to string representation.""" + if isinstance(value, str): + return value + if isinstance(value, dict): + import json + + return json.dumps(value) + return str(value) + + +class LayerLensToLangfuseMapper: + """ + Maps STRATIX canonical events back to Langfuse trace/observation structures. + + Used for exporting STRATIX traces to Langfuse. + """ + + def map_events_to_trace( + self, + events: list[dict[str, Any]], + trace_id: str | None = None, + ) -> dict[str, Any]: + """ + Map a list of STRATIX events to a Langfuse trace with observations. + + Returns a dict with 'trace' (trace body) and 'observations' (list of observations). 
+ """ + trace_id = trace_id or str(uuid.uuid4()) + + trace_body: dict[str, Any] = { + "id": trace_id, + "name": "stratix-export", + "tags": ["stratix-exported"], + "metadata": {"stratix_trace_id": trace_id}, + } + observations: list[dict[str, Any]] = [] + + for event in events: + event_type = event.get("event_type", "") + payload = event.get("payload", {}) + timestamp = event.get("timestamp", datetime.now(UTC).isoformat()) + # Event-level metadata is not carried over to Langfuse. + + if event_type == "agent.input": + trace_body["input"] = payload.get("input", payload.get("input_text")) + if not trace_body.get("name") or trace_body["name"] == "stratix-export": + agent_id = payload.get("agent_id") + if agent_id: + trace_body["name"] = agent_id + + elif event_type == "agent.output": + trace_body["output"] = payload.get("output", payload.get("output_text")) + + elif event_type == "model.invoke": + obs = self._make_generation(payload, timestamp, trace_id) + observations.append(obs) + + elif event_type == "tool.call": + obs = self._make_tool_span(payload, timestamp, trace_id) + observations.append(obs) + + elif event_type == "agent.code": + obs = self._make_default_span(payload, timestamp, trace_id) + observations.append(obs) + + elif event_type == "cost.record": + # Cost is attached to the corresponding generation; find the match + self._attach_cost(observations, payload) + + elif event_type == "environment.config": + config = payload.get("config", {}) + existing_meta = trace_body.get("metadata", {}) + existing_meta["environment_config"] = config + trace_body["metadata"] = existing_meta + + elif event_type == "agent.handoff": + obs = self._make_handoff_span(payload, timestamp, trace_id) + observations.append(obs) + + elif event_type == "agent.state.change": + obs = self._make_state_span(payload, timestamp, trace_id) + observations.append(obs) + + return {"trace": trace_body, "observations": observations} + + @staticmethod + def _make_generation( + payload: dict[str, Any], + timestamp: str, + trace_id: 
str, + ) -> dict[str, Any]: + """Create a Langfuse generation observation from a model.invoke event.""" + gen: dict[str, Any] = { + "id": str(uuid.uuid4()), + "traceId": trace_id, + "type": "GENERATION", + "name": payload.get("model", "unknown-model"), + "startTime": timestamp, + "model": payload.get("model"), + } + # Token usage + usage: dict[str, Any] = {} + if payload.get("tokens_prompt"): + usage["promptTokens"] = payload["tokens_prompt"] + if payload.get("tokens_completion"): + usage["completionTokens"] = payload["tokens_completion"] + if payload.get("tokens_total"): + usage["totalTokens"] = payload["tokens_total"] + if usage: + gen["usage"] = usage + + # Parameters + if payload.get("parameters"): + gen["modelParameters"] = payload["parameters"] + + # Latency -> end time (latency_ms is not added; endTime only marks the observation as finished) + if payload.get("latency_ms"): + gen["endTime"] = timestamp # Approximate: equals startTime + + # Error + if payload.get("error"): + gen["level"] = "ERROR" + gen["statusMessage"] = payload["error"] + + return gen + + @staticmethod + def _make_tool_span( + payload: dict[str, Any], + timestamp: str, + trace_id: str, + ) -> dict[str, Any]: + """Create a Langfuse TOOL span from a tool.call event.""" + span: dict[str, Any] = { + "id": str(uuid.uuid4()), + "traceId": trace_id, + "type": "SPAN", + "name": payload.get("tool_name", "unknown-tool"), + "startTime": timestamp, + "metadata": {"type": "TOOL"}, + } + if payload.get("input") is not None: + span["input"] = payload["input"] + if payload.get("output") is not None: + span["output"] = payload["output"] + if payload.get("error"): + span["level"] = "ERROR" + span["statusMessage"] = payload["error"] + return span + + @staticmethod + def _make_default_span( + payload: dict[str, Any], + timestamp: str, + trace_id: str, + ) -> dict[str, Any]: + """Create a Langfuse DEFAULT span from an agent.code event.""" + span: dict[str, Any] = { + "id": str(uuid.uuid4()), + "traceId": trace_id, + "type": "SPAN", + "name": payload.get("step_name", "execution-step"), + "startTime": 
timestamp, + } + if payload.get("input") is not None: + span["input"] = payload["input"] + if payload.get("output") is not None: + span["output"] = payload["output"] + return span + + @staticmethod + def _make_handoff_span( + payload: dict[str, Any], + timestamp: str, + trace_id: str, + ) -> dict[str, Any]: + """Create a Langfuse span for agent.handoff event.""" + return { + "id": str(uuid.uuid4()), + "traceId": trace_id, + "type": "SPAN", + "name": f"handoff:{payload.get('from_agent', '?')}->{payload.get('to_agent', '?')}", + "startTime": timestamp, + "metadata": { + "type": "HANDOFF", + "from_agent": payload.get("from_agent"), + "to_agent": payload.get("to_agent"), + "context": payload.get("context"), + }, + } + + @staticmethod + def _make_state_span( + payload: dict[str, Any], + timestamp: str, + trace_id: str, + ) -> dict[str, Any]: + """Create a Langfuse span for agent.state.change event.""" + return { + "id": str(uuid.uuid4()), + "traceId": trace_id, + "type": "SPAN", + "name": f"state-change:{payload.get('state_type', 'unknown')}", + "startTime": timestamp, + "metadata": { + "type": "STATE_CHANGE", + "before": payload.get("before"), + "after": payload.get("after"), + }, + } + + @staticmethod + def _attach_cost( + observations: list[dict[str, Any]], + cost_payload: dict[str, Any], + ) -> None: + """Attach cost to the matching generation observation.""" + model = cost_payload.get("model") + cost_usd = cost_payload.get("cost_usd") + if cost_usd is None: + return + + # Find a matching generation by model name + for obs in reversed(observations): + if obs.get("type") == "GENERATION": # noqa: SIM102 + if model is None or obs.get("model") == model: + obs["totalCost"] = cost_usd + return + # No match — attach to last generation if any + for obs in reversed(observations): + if obs.get("type") == "GENERATION": + obs["totalCost"] = cost_usd + return diff --git a/src/layerlens/instrument/adapters/frameworks/langfuse/sync.py 
b/src/layerlens/instrument/adapters/frameworks/langfuse/sync.py new file mode 100644 index 00000000..9bd43ed2 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langfuse/sync.py @@ -0,0 +1,89 @@ +""" +Langfuse Bidirectional Sync + +Coordinates import and export with cursor tracking and conflict resolution. +""" + +from __future__ import annotations + +import logging +from typing import Any +from datetime import datetime + +from layerlens.instrument.adapters.frameworks.langfuse.config import SyncState, SyncResult, SyncDirection +from layerlens.instrument.adapters.frameworks.langfuse.exporter import TraceExporter +from layerlens.instrument.adapters.frameworks.langfuse.importer import TraceImporter + +logger = logging.getLogger(__name__) + + +class BidirectionalSync: + """ + Orchestrates bidirectional sync between Langfuse and STRATIX. + + Uses cursor-based incremental sync to minimize API calls. + """ + + def __init__( + self, + importer: TraceImporter, + exporter: TraceExporter, + state: SyncState, + ) -> None: + self._importer = importer + self._exporter = exporter + self._state = state + + def run( + self, + stratix: Any | None = None, + direction: SyncDirection = SyncDirection.BIDIRECTIONAL, + since: datetime | None = None, + dry_run: bool = False, + events_by_trace: dict[str, list[dict[str, Any]]] | None = None, + tags: list[str] | None = None, + ) -> SyncResult: + """ + Run a sync cycle. + + Args: + stratix: STRATIX instance for event emission. + direction: Sync direction (import, export, or bidirectional). + since: Override since timestamp. + dry_run: If True, report what would happen without making changes. + events_by_trace: STRATIX events for export (required for export/bidirectional). + tags: Filter tags for import. + + Returns: + Combined SyncResult. 
+ """ + result = SyncResult(direction=direction, dry_run=dry_run) + + # Import phase + if direction in (SyncDirection.IMPORT, SyncDirection.BIDIRECTIONAL): + effective_since = since or self._state.last_import_cursor + import_result = self._importer.import_traces( + stratix=stratix, + since=effective_since, + tags=tags, + dry_run=dry_run, + ) + result.imported_count = import_result.imported_count + result.skipped_count += import_result.skipped_count + result.failed_count += import_result.failed_count + result.quarantined_count += import_result.quarantined_count + result.errors.extend(import_result.errors) + + # Export phase + if direction in (SyncDirection.EXPORT, SyncDirection.BIDIRECTIONAL): # noqa: SIM102 + if events_by_trace: + export_result = self._exporter.export_traces( + events_by_trace=events_by_trace, + dry_run=dry_run, + ) + result.exported_count = export_result.exported_count + result.skipped_count += export_result.skipped_count + result.failed_count += export_result.failed_count + result.errors.extend(export_result.errors) + + return result diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/__init__.py b/src/layerlens/instrument/adapters/frameworks/langgraph/__init__.py new file mode 100644 index 00000000..fa2d0c39 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/__init__.py @@ -0,0 +1,58 @@ +""" +STRATIX LangGraph Adapter + +Integrates STRATIX tracing with LangGraph agent framework. + +Usage: + from layerlens.instrument.adapters.frameworks.langgraph import ( + LayerLensLangGraphAdapter, + trace_langgraph_tool, + wrap_llm_for_langgraph, + ) + + # Create adapter + adapter = LayerLensLangGraphAdapter(stratix_instance) + + # Wrap your graph + traced_graph = adapter.wrap_graph(my_graph) + + # Or use decorators for individual components + @trace_langgraph_tool + def my_tool(state): + ... 
+""" + +from __future__ import annotations + +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat, requires_pydantic + +# Round-2 deliberation item 20: LangGraph >=0.2 inherits langchain-core's +# Pydantic v2 requirement; fail fast under v1 with a clear message. +requires_pydantic(PydanticCompat.V2_ONLY) + +from layerlens.instrument.adapters.frameworks.langgraph.llm import TracedLLM, wrap_llm_for_langgraph +from layerlens.instrument.adapters.frameworks.langgraph.nodes import NodeTracer, trace_node +from layerlens.instrument.adapters.frameworks.langgraph.state import LangGraphStateAdapter +from layerlens.instrument.adapters.frameworks.langgraph.tools import trace_langgraph_tool +from layerlens.instrument.adapters.frameworks.langgraph.handoff import HandoffDetector, detect_handoff +from layerlens.instrument.adapters.frameworks.langgraph.lifecycle import LayerLensLangGraphAdapter + +# Registry lazy-loading convention +ADAPTER_CLASS = LayerLensLangGraphAdapter + +__all__ = [ + "LangGraphStateAdapter", + "LayerLensLangGraphAdapter", + "trace_node", + "NodeTracer", + "trace_langgraph_tool", + "wrap_llm_for_langgraph", + "TracedLLM", + "HandoffDetector", + "detect_handoff", + "ADAPTER_CLASS", +] + + +# Backward-compat aliases for users coming from ateam. +STRATIXLangGraphAdapter = LayerLensLangGraphAdapter # noqa: N816 - backward-compat alias for ateam users diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/handoff.py b/src/layerlens/instrument/adapters/frameworks/langgraph/handoff.py new file mode 100644 index 00000000..68424f96 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/handoff.py @@ -0,0 +1,385 @@ +""" +STRATIX LangGraph Handoff Detection + +Detects and traces agent handoffs in multi-agent LangGraph workflows. 
+""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from collections.abc import Callable + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + +logger = logging.getLogger(__name__) + + +@dataclass +class AgentHandoff: + """Represents a handoff between agents.""" + + from_agent: str + to_agent: str + timestamp_ns: int + context: dict[str, Any] | None = None + reason: str | None = None + + +class HandoffDetector: + """ + Detects agent handoffs in LangGraph multi-agent workflows. + + Handoffs occur when: + - A supervisor routes to a different agent + - Control transfers between agent nodes + - An agent explicitly delegates to another + + Usage: + detector = HandoffDetector(stratix_instance) + + # Register agents + detector.register_agent("researcher") + detector.register_agent("writer") + + # Check for handoff + if detector.is_handoff("researcher", "writer", state): + detector.emit_handoff("researcher", "writer", state) + """ + + def __init__( + self, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + ) -> None: + """ + Initialize the handoff detector. + + Args: + stratix_instance: STRATIX SDK instance (legacy) + adapter: BaseAdapter instance (new-style) + """ + self._stratix = stratix_instance + self._adapter = adapter + self._registered_agents: set[str] = set() + self._current_agent: str | None = None + self._handoffs: list[AgentHandoff] = [] + + def register_agent(self, agent_name: str) -> None: + """ + Register an agent for handoff tracking. + + Args: + agent_name: Name of the agent + """ + self._registered_agents.add(agent_name) + + def register_agents(self, *agent_names: str) -> None: + """ + Register multiple agents for handoff tracking. 
+ + Args: + *agent_names: Names of agents + """ + for name in agent_names: + self._registered_agents.add(name) + + def set_current_agent(self, agent_name: str) -> None: + """ + Set the currently active agent. + + Args: + agent_name: Name of the current agent + """ + self._current_agent = agent_name + + def is_handoff( + self, + from_agent: str, + to_agent: str, + state: dict[str, Any] | None = None, + ) -> bool: + """ + Check if this represents a handoff. + + Args: + from_agent: Source agent + to_agent: Destination agent + state: Current state (optional) + + Returns: + True if this is a handoff + """ + # Different agents = handoff + return from_agent != to_agent + + def detect_handoff( + self, + next_agent: str, + state: dict[str, Any] | None = None, + ) -> AgentHandoff | None: + """ + Detect if transitioning to next_agent is a handoff. + + Args: + next_agent: The next agent to execute + state: Current state + + Returns: + AgentHandoff if detected, None otherwise + """ + if self._current_agent and self._current_agent != next_agent: + handoff = AgentHandoff( + from_agent=self._current_agent, + to_agent=next_agent, + timestamp_ns=time.time_ns(), + context=self._extract_context(state) if state else None, + ) + self._handoffs.append(handoff) + self._current_agent = next_agent + self._emit_handoff(handoff) + return handoff + + self._current_agent = next_agent + return None + + def emit_handoff( + self, + from_agent: str, + to_agent: str, + state: dict[str, Any] | None = None, + reason: str | None = None, + ) -> AgentHandoff: + """ + Explicitly emit a handoff event. 
+ + Args: + from_agent: Source agent + to_agent: Destination agent + state: Current state + reason: Reason for handoff + + Returns: + Created AgentHandoff + """ + handoff = AgentHandoff( + from_agent=from_agent, + to_agent=to_agent, + timestamp_ns=time.time_ns(), + context=self._extract_context(state) if state else None, + reason=reason, + ) + self._handoffs.append(handoff) + self._current_agent = to_agent + self._emit_handoff(handoff) + return handoff + + def _extract_context(self, state: dict[str, Any]) -> dict[str, Any]: + """Extract relevant context from state for handoff tracking.""" + context = {} + + # Extract common handoff-related state keys + for key in ["task", "current_task", "objective", "query", "messages"]: + if key in state: + value = state[key] + # Truncate long values + if isinstance(value, str) and len(value) > 500: + context[key] = value[:500] + "..." + elif isinstance(value, list) and len(value) > 10: + context[key] = f"[{len(value)} items]" + else: + context[key] = value + + return context + + def _emit_handoff(self, handoff: AgentHandoff) -> None: + """Emit agent.handoff event via adapter (preferred) or legacy path.""" + payload_dict = { + "from_agent": handoff.from_agent, + "to_agent": handoff.to_agent, + "timestamp_ns": handoff.timestamp_ns, + "context": handoff.context, + "reason": handoff.reason, + } + + # New-style: route through adapter.emit_event + if self._adapter is not None: + try: + import json + import hashlib + + from layerlens.instrument._vendored.events import AgentHandoffEvent + + context_str = json.dumps(handoff.context or {}, sort_keys=True) + context_hash = "sha256:" + hashlib.sha256(context_str.encode()).hexdigest() + typed_payload = AgentHandoffEvent.create( + from_agent=handoff.from_agent, + to_agent=handoff.to_agent, + handoff_context_hash=context_hash, + ) + self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + # 
Legacy fallback + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit("agent.handoff", payload_dict) + + +def detect_handoff( + from_agent: str, + to_agent: str, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + state: dict[str, Any] | None = None, + reason: str | None = None, +) -> AgentHandoff | None: + """ + Utility function to detect and emit a handoff event. + + Args: + from_agent: Source agent + to_agent: Destination agent + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + state: Current state + reason: Reason for handoff + + Returns: + AgentHandoff if detected, None if same agent + """ + if from_agent == to_agent: + return None + + detector = HandoffDetector(stratix_instance, adapter=adapter) + return detector.emit_handoff(from_agent, to_agent, state, reason) + + +class SupervisorHandoffTracker: + """ + Tracks handoffs in supervisor-style multi-agent architectures. + + In a supervisor architecture, a supervisor agent routes tasks + to worker agents. This tracker monitors these transitions. + + Usage: + tracker = SupervisorHandoffTracker(stratix_instance) + + # In supervisor node + def supervisor(state): + next_agent = decide_next_agent(state) + tracker.route_to(next_agent, state) + return {"next": next_agent} + """ + + def __init__( + self, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + supervisor_name: str = "supervisor", + ) -> None: + """ + Initialize the supervisor tracker. 
+ + Args: + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + supervisor_name: Name of the supervisor agent + """ + self._detector = HandoffDetector(stratix_instance, adapter=adapter) + self._supervisor_name = supervisor_name + self._detector.register_agent(supervisor_name) + self._detector.set_current_agent(supervisor_name) + self._last_worker: str | None = None + + def register_worker(self, worker_name: str) -> None: + """ + Register a worker agent. + + Args: + worker_name: Name of the worker agent + """ + self._detector.register_agent(worker_name) + + def route_to( + self, + worker_name: str, + state: dict[str, Any] | None = None, + reason: str | None = None, + ) -> AgentHandoff: + """ + Track routing from supervisor to worker. + + Args: + worker_name: Worker to route to + state: Current state + reason: Reason for routing decision + + Returns: + AgentHandoff event + """ + from_agent = self._last_worker or self._supervisor_name + handoff = self._detector.emit_handoff( + from_agent=from_agent, + to_agent=worker_name, + state=state, + reason=reason or f"Supervisor routed to {worker_name}", + ) + self._last_worker = worker_name + return handoff + + def return_to_supervisor( + self, + state: dict[str, Any] | None = None, + reason: str | None = None, + ) -> AgentHandoff | None: + """ + Track return from worker to supervisor. 
+ + Args: + state: Current state + reason: Reason for return + + Returns: + AgentHandoff event or None if already at supervisor + """ + if self._last_worker: + handoff = self._detector.emit_handoff( + from_agent=self._last_worker, + to_agent=self._supervisor_name, + state=state, + reason=reason or "Worker completed, returning to supervisor", + ) + self._last_worker = None + return handoff + return None + + +def create_handoff_aware_router( + route_func: Callable[[dict[str, Any]], str], + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, +) -> Callable[[dict[str, Any]], dict[str, Any]]: + """ + Create a router function that tracks handoffs. + + Args: + route_func: Function that takes state and returns next agent name + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + + Returns: + Router function that also emits handoff events + """ + detector = HandoffDetector(stratix_instance, adapter=adapter) + + def router(state: dict[str, Any]) -> dict[str, Any]: + next_agent = route_func(state) + detector.detect_handoff(next_agent, state) + return {"next": next_agent} + + return router diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/lifecycle.py b/src/layerlens/instrument/adapters/frameworks/langgraph/lifecycle.py new file mode 100644 index 00000000..e0fbb915 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/lifecycle.py @@ -0,0 +1,533 @@ +""" +STRATIX LangGraph Lifecycle Hooks + +Provides graph start/end hooks for STRATIX tracing. 
+""" + +from __future__ import annotations + +import time +import uuid +from typing import TYPE_CHECKING, Any, TypeVar +from dataclasses import field, dataclass + +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters._base.pydantic_compat import PydanticCompat +from layerlens.instrument.adapters.frameworks.langgraph.state import LangGraphStateAdapter + +if TYPE_CHECKING: + from layerlens.instrument.adapters.frameworks.langgraph.handoff import HandoffDetector + + +StateT = TypeVar("StateT") +GraphT = TypeVar("GraphT") + + +@dataclass +class GraphExecution: + """Represents a single graph execution.""" + + graph_id: str + execution_id: str + start_time_ns: int + end_time_ns: int | None = None + initial_state_hash: str | None = None + final_state_hash: str | None = None + node_executions: list[dict[str, Any]] = field(default_factory=list) + error: str | None = None + + +class LayerLensLangGraphAdapter(BaseAdapter): + """ + Main adapter for integrating STRATIX with LangGraph. + + This adapter wraps LangGraph graphs to automatically emit STRATIX events + for graph execution, node transitions, and state changes. + + Supports both the new BaseAdapter interface and the legacy constructor + for backward compatibility. 
+ + Usage (new): + from stratix import STRATIX + from layerlens.instrument.adapters.frameworks.langgraph import LayerLensLangGraphAdapter + + stratix = STRATIX(policy_ref="my-policy") + adapter = LayerLensLangGraphAdapter(stratix=stratix) + adapter.connect() + traced_graph = adapter.wrap_graph(my_graph) + result = traced_graph.invoke(initial_state) + + Usage (legacy — still supported): + adapter = LayerLensLangGraphAdapter(stratix_instance=stratix) + traced_graph = adapter.wrap_graph(my_graph) + """ + + FRAMEWORK = "langgraph" + VERSION = "0.1.0" + # LangGraph >=0.2 (pyproject pin: langgraph>=0.2,<0.4) depends on + # langchain-core>=0.2 which is Pydantic v2 only. LangGraph's own + # state schema (StateGraph, MessagesState) uses v2-style typing. + requires_pydantic = PydanticCompat.V2_ONLY + + def __init__( + self, + # New-style params + stratix: Any | None = None, + capture_config: CaptureConfig | None = None, + handoff_detector: HandoffDetector | None = None, + # Legacy params (backward compat) + stratix_instance: Any | None = None, + state_adapter: LangGraphStateAdapter | None = None, + emit_environment_config: bool = True, + emit_agent_code: bool = False, + ) -> None: + """ + Initialize the LangGraph adapter. + + Accepts both new-style (stratix, capture_config) and legacy-style + (stratix_instance, boolean flags) parameters. When legacy params are + provided, they are mapped to CaptureConfig equivalents. 
+ + Args: + stratix: STRATIX SDK instance (new-style) + capture_config: CaptureConfig (new-style) + handoff_detector: HandoffDetector for automatic handoff detection + during node transitions (optional) + stratix_instance: STRATIX SDK instance (legacy) + state_adapter: Custom state adapter (uses default if not provided) + emit_environment_config: Whether to emit environment.config (legacy) + emit_agent_code: Whether to emit agent.code (legacy) + """ + # Resolve STRATIX instance: new-style takes priority + resolved_stratix = stratix or stratix_instance + + # Map legacy params → CaptureConfig when a flag differs from its default or stratix_instance is passed + if capture_config is None: + any_legacy = not emit_environment_config or emit_agent_code + if any_legacy or stratix_instance is not None: + capture_config = CaptureConfig( + l4a_environment_config=emit_environment_config, + l2_agent_code=emit_agent_code, + ) + + super().__init__(stratix=resolved_stratix, capture_config=capture_config) + + self._state_adapter = state_adapter or LangGraphStateAdapter() + self._executions: list[GraphExecution] = [] + self._handoff_detector: HandoffDetector | None = handoff_detector + + # Legacy compat: keep booleans accessible for code that reads them + self._emit_environment_config = emit_environment_config + self._emit_agent_code = emit_agent_code + + # --- BaseAdapter lifecycle --- + + def connect(self) -> None: + """Mark the adapter as connected, whether or not LangGraph is importable.""" + try: + import langgraph # type: ignore[import-not-found,unused-ignore] # noqa: F401 + + self._connected = True + self._status = AdapterStatus.HEALTHY + except ImportError: + # Still usable without LangGraph installed (for mock/test use) + self._connected = True + self._status = AdapterStatus.HEALTHY + + def disconnect(self) -> None: + """Mark the adapter as disconnected.""" + self._connected = False + self._status = AdapterStatus.DISCONNECTED + + def health_check(self) -> AdapterHealth: + return AdapterHealth( + status=self._status, + 
framework_name=self.FRAMEWORK, + framework_version=self._detect_framework_version(), + adapter_version=self.VERSION, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name="LayerLensLangGraphAdapter", + version=self.VERSION, + framework=self.FRAMEWORK, + framework_version=self._detect_framework_version(), + capabilities=[ + AdapterCapability.TRACE_TOOLS, + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_STATE, + AdapterCapability.TRACE_HANDOFFS, + AdapterCapability.REPLAY, + ], + description="LayerLens adapter for LangGraph agent framework", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + trace_id = str(uuid.uuid4()) + return ReplayableTrace( + adapter_name="LayerLensLangGraphAdapter", + framework=self.FRAMEWORK, + trace_id=trace_id, + events=list(self._trace_events), + state_snapshots=[], + config={ + "capture_config": self._capture_config.model_dump(), + }, + ) + + # --- Handoff detection --- + + def set_handoff_detector(self, detector: HandoffDetector) -> None: + """ + Attach a HandoffDetector to this adapter. + + When set, ``on_node_start`` will automatically call + ``detector.detect_handoff(node_name, state)`` on every + node transition, emitting handoff events when the active + agent changes. + + Args: + detector: HandoffDetector instance (should already have + agents registered via ``register_agent`` / + ``register_agents``) + """ + self._handoff_detector = detector + + @property + def handoff_detector(self) -> HandoffDetector | None: + """Return the attached HandoffDetector, or None.""" + return self._handoff_detector + + # --- Graph wrapping --- + + def wrap_graph(self, graph: GraphT) -> GraphT: + """ + Wrap a LangGraph compiled graph with STRATIX tracing. 
+ + Args: + graph: Compiled LangGraph graph + + Returns: + Wrapped graph with same interface + """ + return _TracedGraph( # type: ignore[return-value] + graph=graph, + adapter=self, + state_adapter=self._state_adapter, + ) + + # --- Lifecycle hooks --- + + def on_graph_start( + self, + graph_id: str, + execution_id: str, + initial_state: Any, + config: dict[str, Any] | None = None, + ) -> GraphExecution: + """ + Handle graph execution start. + + Emits: + - environment.config (if enabled) + - agent.input + + Args: + graph_id: Identifier for the graph + execution_id: Unique execution ID + initial_state: Initial graph state + config: Graph execution config + + Returns: + GraphExecution tracking object + """ + execution = GraphExecution( + graph_id=graph_id, + execution_id=execution_id, + start_time_ns=time.time_ns(), + initial_state_hash=self._state_adapter.get_hash(initial_state), + ) + self._executions.append(execution) + + # Emit environment config (gated by CaptureConfig inside emit_dict_event) + self.emit_dict_event( + "environment.config", + { + "framework": "langgraph", + "graph_id": graph_id, + "config": config, + }, + ) + + # Emit agent input + self.emit_dict_event( + "agent.input", + { + "graph_id": graph_id, + "execution_id": execution_id, + "initial_state": self._safe_serialize(initial_state), + }, + ) + + return execution + + def on_graph_end( + self, + execution: GraphExecution, + final_state: Any, + error: Exception | None = None, + ) -> None: + """ + Handle graph execution end. 
+ + Emits: + - agent.output + - agent.state.change (if state changed) + + Args: + execution: Execution tracking object + final_state: Final graph state + error: Exception if execution failed + """ + execution.end_time_ns = time.time_ns() + execution.final_state_hash = self._state_adapter.get_hash(final_state) + + if error: + execution.error = str(error) + + # Emit agent output (gated by CaptureConfig inside emit_dict_event) + self.emit_dict_event( + "agent.output", + { + "graph_id": execution.graph_id, + "execution_id": execution.execution_id, + "final_state": self._safe_serialize(final_state), + "duration_ns": execution.end_time_ns - execution.start_time_ns, + "error": execution.error, + }, + ) + + # Emit state change if state changed (cross-cutting — always enabled) + if execution.initial_state_hash != execution.final_state_hash: + self.emit_dict_event( + "agent.state.change", + { + "graph_id": execution.graph_id, + "execution_id": execution.execution_id, + "before_hash": execution.initial_state_hash, + "after_hash": execution.final_state_hash, + }, + ) + + def on_node_start( + self, + execution: GraphExecution, + node_name: str, + state: Any, + ) -> dict[str, Any]: + """ + Handle node execution start. + + If a HandoffDetector is attached, automatically feeds the node + transition to it so that agent-to-agent handoffs are detected + and emitted. 
+ + Args: + execution: Execution tracking object + node_name: Name of the node + state: Current state + + Returns: + Node execution context for tracking + """ + node_context = { + "node_name": node_name, + "start_time_ns": time.time_ns(), + "state_hash_before": self._state_adapter.get_hash(state), + } + + if self._handoff_detector is not None: + self._handoff_detector.detect_handoff( + node_name, + state if isinstance(state, dict) else None, + ) + + return node_context + + def on_node_end( + self, + execution: GraphExecution, + node_context: dict[str, Any], + state: Any, + error: Exception | None = None, + ) -> None: + """ + Handle node execution end. + + Emits: + - agent.state.change (if state changed at this node) + + Args: + execution: Execution tracking object + node_context: Node context from on_node_start + state: State after node execution + error: Exception if node failed + """ + node_context["end_time_ns"] = time.time_ns() + node_context["state_hash_after"] = self._state_adapter.get_hash(state) + node_context["duration_ns"] = node_context["end_time_ns"] - node_context["start_time_ns"] + + if error: + node_context["error"] = str(error) + + execution.node_executions.append(node_context) + + # Emit state change if node modified state (cross-cutting — always enabled) + if node_context["state_hash_before"] != node_context["state_hash_after"]: + self.emit_dict_event( + "agent.state.change", + { + "graph_id": execution.graph_id, + "execution_id": execution.execution_id, + "node_name": node_context["node_name"], + "before_hash": node_context["state_hash_before"], + "after_hash": node_context["state_hash_after"], + }, + ) + + # --- Internal helpers --- + + def _safe_serialize(self, value: Any) -> Any: + """Safely serialize a value for events.""" + try: + if hasattr(value, "dict"): + return value.dict() + elif isinstance(value, dict): + return dict(value) + else: + return str(value) + except Exception: + return str(value) + + @staticmethod + def 
_detect_framework_version() -> str | None: + try: + import langgraph # type: ignore[import-not-found,unused-ignore] + + return getattr(langgraph, "__version__", None) + except ImportError: + return None + + +class _TracedGraph: + """ + Wrapper around a LangGraph compiled graph that adds STRATIX tracing. + """ + + def __init__( + self, + graph: Any, + adapter: LayerLensLangGraphAdapter, + state_adapter: LangGraphStateAdapter, + ) -> None: + self._graph = graph + self._adapter = adapter + self._state_adapter = state_adapter + self._execution_count = 0 + + def invoke(self, state: Any, config: dict[str, Any] | None = None) -> Any: + """ + Invoke the graph with tracing. + + Args: + state: Initial state + config: Execution config + + Returns: + Final state + """ + self._execution_count += 1 + graph_id = self._get_graph_id() + execution_id = f"{graph_id}:{self._execution_count}" + + # Start tracking + execution = self._adapter.on_graph_start( + graph_id=graph_id, + execution_id=execution_id, + initial_state=state, + config=config, + ) + + try: + # Execute the actual graph + result = self._graph.invoke(state, config) + + # End tracking + self._adapter.on_graph_end(execution, result) + + return result + + except Exception as e: + # End tracking with error + self._adapter.on_graph_end(execution, state, error=e) + raise + + async def ainvoke(self, state: Any, config: dict[str, Any] | None = None) -> Any: + """ + Async invoke the graph with tracing. 
+ + Args: + state: Initial state + config: Execution config + + Returns: + Final state + """ + self._execution_count += 1 + graph_id = self._get_graph_id() + execution_id = f"{graph_id}:{self._execution_count}" + + # Start tracking + execution = self._adapter.on_graph_start( + graph_id=graph_id, + execution_id=execution_id, + initial_state=state, + config=config, + ) + + try: + # Execute the actual graph + result = await self._graph.ainvoke(state, config) + + # End tracking + self._adapter.on_graph_end(execution, result) + + return result + + except Exception as e: + # End tracking with error + self._adapter.on_graph_end(execution, state, error=e) + raise + + def _get_graph_id(self) -> str: + """Get the graph identifier.""" + if hasattr(self._graph, "name"): + return self._graph.name # type: ignore[no-any-return] + elif hasattr(self._graph, "__class__"): + return self._graph.__class__.__name__ # type: ignore[no-any-return] + return "langgraph" + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to underlying graph.""" + return getattr(self._graph, name) diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/llm.py b/src/layerlens/instrument/adapters/frameworks/langgraph/llm.py new file mode 100644 index 00000000..30b0e3c3 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/llm.py @@ -0,0 +1,411 @@ +""" +STRATIX LangGraph LLM Wrapper + +Wraps LLM calls to emit model.invoke (L3) events. 
+""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any, TypeVar +from dataclasses import dataclass + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + +logger = logging.getLogger(__name__) + + +MessageT = TypeVar("MessageT") + + +@dataclass +class LLMInvocation: + """Tracks a single LLM invocation.""" + + model: str + provider: str + start_time_ns: int + end_time_ns: int | None = None + input_messages: list[Any] | None = None + output_message: Any | None = None + token_usage: dict[str, int] | None = None + error: str | None = None + + +class TracedLLM: + """ + Wrapper around an LLM that emits model.invoke events. + + Compatible with LangChain/LangGraph chat models. + + Usage: + from langchain_openai import ChatOpenAI # type: ignore[import-untyped,unused-ignore] + + llm = ChatOpenAI(model="gpt-4") + traced_llm = TracedLLM(llm, stratix_instance=stratix) + + # Use as normal + response = traced_llm.invoke(messages) + """ + + def __init__( + self, + llm: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + model_name: str | None = None, + provider: str | None = None, + ) -> None: + """ + Initialize the traced LLM. + + Args: + llm: The underlying LLM instance + stratix_instance: STRATIX SDK instance (legacy) + adapter: BaseAdapter instance (new-style) + model_name: Model name override (auto-detected if not provided) + provider: Provider name override (auto-detected if not provided) + """ + self._llm = llm + self._stratix = stratix_instance + self._adapter = adapter + self._model_name = model_name or self._detect_model_name() + self._provider = provider or self._detect_provider() + self._invocations: list[LLMInvocation] = [] + + def invoke(self, messages: Any, **kwargs: Any) -> Any: + """ + Invoke the LLM with tracing. 
+ + Args: + messages: Input messages + **kwargs: Additional arguments + + Returns: + LLM response + """ + invocation = LLMInvocation( + model=self._model_name, + provider=self._provider, + start_time_ns=time.time_ns(), + input_messages=self._serialize_messages(messages), + ) + self._invocations.append(invocation) + + try: + response = self._llm.invoke(messages, **kwargs) + invocation.end_time_ns = time.time_ns() + invocation.output_message = self._serialize_response(response) + invocation.token_usage = self._extract_token_usage(response) + self._emit_model_invoke(invocation) + return response + + except Exception as e: + invocation.end_time_ns = time.time_ns() + invocation.error = str(e) + self._emit_model_invoke(invocation) + raise + + async def ainvoke(self, messages: Any, **kwargs: Any) -> Any: + """ + Async invoke the LLM with tracing. + + Args: + messages: Input messages + **kwargs: Additional arguments + + Returns: + LLM response + """ + invocation = LLMInvocation( + model=self._model_name, + provider=self._provider, + start_time_ns=time.time_ns(), + input_messages=self._serialize_messages(messages), + ) + self._invocations.append(invocation) + + try: + response = await self._llm.ainvoke(messages, **kwargs) + invocation.end_time_ns = time.time_ns() + invocation.output_message = self._serialize_response(response) + invocation.token_usage = self._extract_token_usage(response) + self._emit_model_invoke(invocation) + return response + + except Exception as e: + invocation.end_time_ns = time.time_ns() + invocation.error = str(e) + self._emit_model_invoke(invocation) + raise + + def stream(self, messages: Any, **kwargs: Any) -> Any: + """ + Stream the LLM response with tracing. + + Note: For streaming, we emit the event after the stream is consumed. 
+ + Args: + messages: Input messages + **kwargs: Additional arguments + + Yields: + Response chunks + """ + invocation = LLMInvocation( + model=self._model_name, + provider=self._provider, + start_time_ns=time.time_ns(), + input_messages=self._serialize_messages(messages), + ) + self._invocations.append(invocation) + + try: + chunks = [] + for chunk in self._llm.stream(messages, **kwargs): + chunks.append(chunk) + yield chunk + + invocation.end_time_ns = time.time_ns() + invocation.output_message = self._combine_chunks(chunks) + self._emit_model_invoke(invocation) + + except Exception as e: + invocation.end_time_ns = time.time_ns() + invocation.error = str(e) + self._emit_model_invoke(invocation) + raise + + def _detect_model_name(self) -> str: + """Auto-detect model name from LLM instance.""" + # Try common attribute names + for attr in ["model_name", "model", "_model_name", "model_id"]: + if hasattr(self._llm, attr): + value = getattr(self._llm, attr) + if value: + return str(value) + return "unknown" + + def _detect_provider(self) -> str: + """Auto-detect provider from LLM instance.""" + class_name = self._llm.__class__.__name__.lower() + + if "openai" in class_name: + return "openai" + elif "anthropic" in class_name or "claude" in class_name: + return "anthropic" + elif "google" in class_name or "gemini" in class_name: + return "google" + elif "cohere" in class_name: + return "cohere" + elif "huggingface" in class_name: + return "huggingface" + + return "unknown" + + def _serialize_messages(self, messages: Any) -> list[dict[str, Any]]: + """Serialize input messages.""" + if isinstance(messages, str): + return [{"role": "user", "content": messages}] + + result = [] + if isinstance(messages, list): + for msg in messages: + if isinstance(msg, dict): + result.append(msg) + elif hasattr(msg, "content") and hasattr(msg, "type"): + result.append( + { + "role": getattr(msg, "type", "unknown"), + "content": str(msg.content), + } + ) + else: + result.append({"content": 
str(msg)}) + + return result + + def _serialize_response(self, response: Any) -> dict[str, Any]: + """Serialize LLM response.""" + if isinstance(response, str): + return {"content": response} + + if hasattr(response, "content"): + result = {"content": str(response.content)} + if hasattr(response, "type"): + result["role"] = response.type + return result + + return {"content": str(response)} + + def _extract_token_usage(self, response: Any) -> dict[str, int] | None: + """Extract token usage from response.""" + # Try response_metadata (LangChain style) + if hasattr(response, "response_metadata"): + metadata = response.response_metadata + if isinstance(metadata, dict) and "usage" in metadata: + return metadata["usage"] # type: ignore[no-any-return] + + # Try usage_metadata + if hasattr(response, "usage_metadata"): + return response.usage_metadata # type: ignore[no-any-return] + + return None + + def _combine_chunks(self, chunks: list[Any]) -> dict[str, Any]: + """Combine streaming chunks into single response.""" + content_parts = [] + for chunk in chunks: + if hasattr(chunk, "content"): + content_parts.append(str(chunk.content)) + elif isinstance(chunk, str): + content_parts.append(chunk) + + return {"content": "".join(content_parts)} + + def _emit_model_invoke(self, invocation: LLMInvocation) -> None: + """Emit model.invoke event via adapter (preferred) or legacy path.""" + duration_ns = (invocation.end_time_ns or 0) - invocation.start_time_ns + + # New-style: route through adapter.emit_event + if self._adapter is not None: + try: + from layerlens.instrument._vendored.events import ModelInvokeEvent + + typed_payload = ModelInvokeEvent.create( # type: ignore[call-arg,unused-ignore] + model_name=invocation.model, + provider=invocation.provider, + input_messages=invocation.input_messages or [], + output_message=invocation.output_message, + token_usage=invocation.token_usage, + duration_ns=duration_ns, + error=invocation.error, + ) + 
self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + # Legacy fallback + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit( + "model.invoke", + { + "model": invocation.model, + "provider": invocation.provider, + "input_messages": invocation.input_messages, + "output_message": invocation.output_message, + "token_usage": invocation.token_usage, + "duration_ns": duration_ns, + "error": invocation.error, + }, + ) + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to underlying LLM.""" + return getattr(self._llm, name) + + +def wrap_llm_for_langgraph( + llm: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + model_name: str | None = None, + provider: str | None = None, +) -> TracedLLM: + """ + Wrap an LLM for use in LangGraph with STRATIX tracing. + + Usage: + from langchain_openai import ChatOpenAI # type: ignore[import-untyped,unused-ignore] + + llm = ChatOpenAI(model="gpt-4") + traced_llm = wrap_llm_for_langgraph(llm, stratix_instance=stratix) + + Args: + llm: LLM instance to wrap + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + model_name: Model name override + provider: Provider name override + + Returns: + TracedLLM wrapper + """ + return TracedLLM( + llm=llm, + stratix_instance=stratix_instance, + adapter=adapter, + model_name=model_name, + provider=provider, + ) + + +class LLMCallNode: + """ + A LangGraph node that wraps an LLM call with tracing. + + Usage: + llm_node = LLMCallNode( + llm=ChatOpenAI(), + stratix_instance=stratix, + messages_key="messages", + ) + + graph.add_node("llm", llm_node) + """ + + def __init__( + self, + llm: Any, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + messages_key: str = "messages", + response_key: str = "messages", + ) -> None: + """ + Initialize the LLM call node. 
+ + Args: + llm: LLM instance + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + messages_key: Key in state containing messages + response_key: Key in state to add response to + """ + self._traced_llm = TracedLLM(llm, stratix_instance, adapter=adapter) + self._messages_key = messages_key + self._response_key = response_key + + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """ + Execute the LLM node. + + Args: + state: LangGraph state + + Returns: + Updated state with LLM response + """ + messages = state.get(self._messages_key, []) + response = self._traced_llm.invoke(messages) + + # Return state update + return {self._response_key: [response]} + + async def __acall__(self, state: dict[str, Any]) -> dict[str, Any]: + """ + Async execute the LLM node. + + Args: + state: LangGraph state + + Returns: + Updated state with LLM response + """ + messages = state.get(self._messages_key, []) + response = await self._traced_llm.ainvoke(messages) + + return {self._response_key: [response]} diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/nodes.py b/src/layerlens/instrument/adapters/frameworks/langgraph/nodes.py new file mode 100644 index 00000000..1fa24ab1 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/nodes.py @@ -0,0 +1,286 @@ +""" +STRATIX LangGraph Node Tracing + +Provides node entry/exit hooks and decorators for tracing node execution. 
+""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any, TypeVar +from functools import wraps +from dataclasses import dataclass +from collections.abc import Callable + +from layerlens.instrument.adapters.frameworks.langgraph.state import LangGraphStateAdapter + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + +logger = logging.getLogger(__name__) + + +StateT = TypeVar("StateT") +NodeFunc = Callable[[StateT], StateT] + + +@dataclass +class NodeExecution: + """Tracks a single node execution.""" + + node_name: str + start_time_ns: int + end_time_ns: int | None = None + state_hash_before: str | None = None + state_hash_after: str | None = None + error: str | None = None + + +class NodeTracer: + """ + Tracer for LangGraph node executions. + + Provides hooks for node entry/exit and automatic state change detection. + + Usage: + tracer = NodeTracer(stratix_instance) + + # Manual tracking + with tracer.trace_node("my_node", state) as ctx: + # Node logic here + new_state = process(state) + # Report the result so the state change can be detected + ctx.set_result(new_state) + + # Or use the decorator + @tracer.decorate + def my_node(state): + return process(state) + """ + + def __init__( + self, + stratix_instance: Any = None, + state_adapter: LangGraphStateAdapter | None = None, + adapter: BaseAdapter | None = None, + ) -> None: + """ + Initialize the node tracer. + + Args: + stratix_instance: STRATIX SDK instance (legacy) + state_adapter: State adapter for change detection + adapter: BaseAdapter instance (new-style) + """ + self._stratix = stratix_instance + self._adapter = adapter + self._state_adapter = state_adapter or LangGraphStateAdapter() + self._executions: list[NodeExecution] = [] + + def trace_node(self, node_name: str, state: Any) -> _NodeContext: + """ + Create a context manager for tracing a node.
+ + Args: + node_name: Name of the node + state: Current state + + Returns: + Context manager for node tracing + """ + return _NodeContext( + tracer=self, + node_name=node_name, + state=state, + ) + + def decorate(self, func: NodeFunc) -> NodeFunc: # type: ignore[type-arg] + """ + Decorate a node function with tracing. + + Args: + func: Node function + + Returns: + Decorated function + """ + node_name = func.__name__ + + @wraps(func) + def wrapper(state: StateT) -> StateT: + with self.trace_node(node_name, state) as ctx: + result = func(state) + ctx.set_result(result) + return result # type: ignore[no-any-return] + + return wrapper + + def on_node_enter(self, node_name: str, state: Any) -> NodeExecution: + """ + Called when entering a node. + + Records the start time and the pre-execution state hash. No event is emitted on entry. + + Args: + node_name: Name of the node + state: Current state + + Returns: + NodeExecution tracking object + """ + execution = NodeExecution( + node_name=node_name, + start_time_ns=time.time_ns(), + state_hash_before=self._state_adapter.get_hash(state), + ) + self._executions.append(execution) + + return execution + + def on_node_exit( + self, + execution: NodeExecution, + state: Any, + error: Exception | None = None, + ) -> None: + """ + Called when exiting a node. + + Emits agent.state.change event if state changed.
+ + Args: + execution: Execution tracking object + state: State after node execution + error: Exception if node failed + """ + execution.end_time_ns = time.time_ns() + execution.state_hash_after = self._state_adapter.get_hash(state) + + if error: + execution.error = str(error) + + # Emit state change event if state changed + if execution.state_hash_before != execution.state_hash_after: + self._emit_state_change(execution) + + def _emit_state_change(self, execution: NodeExecution) -> None: + """Emit state change event via adapter (preferred) or legacy path.""" + duration_ns = (execution.end_time_ns or 0) - execution.start_time_ns + + # New-style: route through adapter.emit_event + if self._adapter is not None: + try: + from layerlens.instrument._vendored.events import ( + StateType, + AgentStateChangeEvent, + ) + + typed_payload = AgentStateChangeEvent.create( + state_type=StateType.INTERNAL, + before_hash=execution.state_hash_before or "sha256:" + "0" * 64, + after_hash=execution.state_hash_after or "sha256:" + "0" * 64, + ) + self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + # Legacy fallback + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit( + "agent.state.change", + { + "node_name": execution.node_name, + "before_hash": execution.state_hash_before, + "after_hash": execution.state_hash_after, + "duration_ns": duration_ns, + }, + ) + + +class _NodeContext: + """Context manager for node tracing.""" + + def __init__(self, tracer: NodeTracer, node_name: str, state: Any) -> None: + self._tracer = tracer + self._node_name = node_name + self._state = state + self._result_state: Any = None + self._execution: NodeExecution | None = None + + def __enter__(self) -> _NodeContext: + self._execution = self._tracer.on_node_enter(self._node_name, self._state) + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + 
if self._execution: + # Use result state if set, otherwise use original state + final_state = self._result_state if self._result_state is not None else self._state + error = exc_val if exc_val else None + self._tracer.on_node_exit(self._execution, final_state, error) + + def set_result(self, state: Any) -> None: + """Set the result state for tracking.""" + self._result_state = state + + +def trace_node( + stratix_instance: Any = None, + state_adapter: LangGraphStateAdapter | None = None, + adapter: BaseAdapter | None = None, +) -> Callable[[NodeFunc], NodeFunc]: # type: ignore[type-arg] + """ + Decorator factory for tracing node functions. + + Usage: + @trace_node(stratix) + def my_node(state): + return new_state + + Args: + stratix_instance: STRATIX SDK instance + state_adapter: State adapter for change detection + adapter: BaseAdapter instance (new-style) + + Returns: + Decorator function + """ + tracer = NodeTracer(stratix_instance, state_adapter, adapter=adapter) + + def decorator(func: NodeFunc) -> NodeFunc: # type: ignore[type-arg] + return tracer.decorate(func) + + return decorator + + +def create_traced_node( + func: NodeFunc, # type: ignore[type-arg] + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + node_name: str | None = None, +) -> NodeFunc: # type: ignore[type-arg] + """ + Create a traced version of a node function. + + This is useful when you want to trace existing functions without + modifying them. 
+ + Args: + func: Original node function + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + node_name: Name to use for tracing (defaults to function name) + + Returns: + Traced node function + """ + tracer = NodeTracer(stratix_instance, adapter=adapter) + name = node_name or func.__name__ + + @wraps(func) + def traced_func(state: Any) -> Any: + with tracer.trace_node(name, state) as ctx: + result = func(state) + ctx.set_result(result) + return result + + return traced_func diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/state.py b/src/layerlens/instrument/adapters/frameworks/langgraph/state.py new file mode 100644 index 00000000..c32717b2 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/state.py @@ -0,0 +1,251 @@ +""" +STRATIX LangGraph State Adapter + +Adapts LangGraph graph state for STRATIX state tracking. + +Note: This adapter is designed specifically for LangGraph state management +and does not extend the base StateAdapter, which is designed for STRATIX core integration. +""" + +from __future__ import annotations + +import json +import hashlib +from typing import Any, TypeVar +from dataclasses import dataclass + +StateT = TypeVar("StateT") + + +@dataclass +class StateSnapshot: + """Snapshot of graph state at a point in time.""" + + state: dict[str, Any] + hash: str + timestamp_ns: int + + +class LangGraphStateAdapter: + """ + State adapter for LangGraph graph state. + + Captures state snapshots at node boundaries and detects mutations.
+ + Usage: + adapter = LangGraphStateAdapter() + + # Before node + before_snapshot = adapter.snapshot(state) + + # After node + after_snapshot = adapter.snapshot(state) + + # Check for changes + if adapter.has_changed(before_snapshot, after_snapshot): + changes = adapter.diff(before_snapshot, after_snapshot) + """ + + def __init__( + self, include_keys: list[str] | None = None, exclude_keys: list[str] | None = None + ) -> None: + """ + Initialize the state adapter. + + Args: + include_keys: Only track these keys (if specified) + exclude_keys: Exclude these keys from tracking + """ + self._include_keys = set(include_keys) if include_keys else None + self._exclude_keys = set(exclude_keys) if exclude_keys else set() + + def snapshot(self, state: Any) -> StateSnapshot: + """ + Create a snapshot of the current state. + + Args: + state: LangGraph state (typically a dict or TypedDict) + + Returns: + StateSnapshot with hash for comparison + """ + import time + + # Convert state to dictionary + state_dict = self._to_dict(state) + + # Filter keys if configured + filtered_state = self._filter_state(state_dict) + + # Compute hash + state_hash = self._compute_hash(filtered_state) + + return StateSnapshot( + state=filtered_state, + hash=state_hash, + timestamp_ns=time.time_ns(), + ) + + def has_changed(self, before: StateSnapshot, after: StateSnapshot) -> bool: + """ + Check if state has changed between snapshots. + + Args: + before: Snapshot before operation + after: Snapshot after operation + + Returns: + True if state changed + """ + return before.hash != after.hash + + def diff(self, before: StateSnapshot, after: StateSnapshot) -> dict[str, Any]: + """ + Compute the difference between two snapshots. 
+ + Args: + before: Snapshot before operation + after: Snapshot after operation + + Returns: + Dictionary describing changes: + { + "added": {"key": value}, + "removed": {"key": value}, + "modified": {"key": {"before": old, "after": new}} + } + """ + added = {} + removed = {} + modified = {} + + before_keys = set(before.state.keys()) + after_keys = set(after.state.keys()) + + # Added keys + for key in after_keys - before_keys: + added[key] = after.state[key] + + # Removed keys + for key in before_keys - after_keys: + removed[key] = before.state[key] + + # Modified keys + for key in before_keys & after_keys: + if before.state[key] != after.state[key]: + modified[key] = { + "before": before.state[key], + "after": after.state[key], + } + + return { + "added": added, + "removed": removed, + "modified": modified, + } + + def get_hash(self, state: Any) -> str: + """ + Compute hash of state without creating full snapshot. + + Args: + state: LangGraph state + + Returns: + Hash string + """ + state_dict = self._to_dict(state) + filtered = self._filter_state(state_dict) + return self._compute_hash(filtered) + + def _to_dict(self, state: Any) -> dict[str, Any]: + """Convert state to dictionary.""" + if isinstance(state, dict): + return dict(state) + elif hasattr(state, "__dict__"): + return dict(state.__dict__) + elif hasattr(state, "_asdict"): # NamedTuple + return state._asdict() # type: ignore[no-any-return] + else: + # Try to treat as dict-like + try: + return dict(state) + except (TypeError, ValueError): + return {"__value__": state} + + def _filter_state(self, state: dict[str, Any]) -> dict[str, Any]: + """Apply include/exclude filters.""" + if self._include_keys: + state = {k: v for k, v in state.items() if k in self._include_keys} + + if self._exclude_keys: + state = {k: v for k, v in state.items() if k not in self._exclude_keys} + + return state + + def _compute_hash(self, state: dict[str, Any]) -> str: + """Compute SHA-256 hash of state.""" + # Canonical JSON 
serialization + try: + serialized = json.dumps(state, sort_keys=True, default=str) + except TypeError: + # Fallback for non-serializable objects + serialized = str(state) + + return hashlib.sha256(serialized.encode()).hexdigest() + + +class MessageListAdapter(LangGraphStateAdapter): + """ + Specialized adapter for LangGraph message-based state. + + LangGraph commonly uses a messages list in state. + This adapter optimizes tracking for message append patterns. + """ + + def __init__(self, message_key: str = "messages") -> None: + """ + Initialize the message list adapter. + + Args: + message_key: Key in state that contains the messages list + """ + super().__init__() # Initialize parent with defaults + self._message_key = message_key + self._last_message_count = 0 + + def snapshot(self, state: Any) -> StateSnapshot: + """Create snapshot with message count optimization.""" + snapshot = super().snapshot(state) + + # Track message count for efficient change detection + state_dict = self._to_dict(state) + if self._message_key in state_dict: + messages = state_dict[self._message_key] + if isinstance(messages, list): + self._last_message_count = len(messages) + + return snapshot + + def get_new_messages(self, before: StateSnapshot, after: StateSnapshot) -> list[Any]: + """ + Get messages added between snapshots.
+ + Args: + before: Snapshot before + after: Snapshot after + + Returns: + List of new messages + """ + before_messages = before.state.get(self._message_key, []) + after_messages = after.state.get(self._message_key, []) + + if not isinstance(before_messages, list) or not isinstance(after_messages, list): + return [] + + # Assume messages are appended, not inserted + if len(after_messages) > len(before_messages): + return after_messages[len(before_messages) :] + + return [] diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph/tools.py b/src/layerlens/instrument/adapters/frameworks/langgraph/tools.py new file mode 100644 index 00000000..c0686359 --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph/tools.py @@ -0,0 +1,345 @@ +""" +STRATIX LangGraph Tool Tracing + +Provides decorators and wrappers for tracing LangGraph tool nodes. +""" + +from __future__ import annotations + +import time +import logging +from typing import TYPE_CHECKING, Any, TypeVar +from functools import wraps +from dataclasses import dataclass +from collections.abc import Callable + +if TYPE_CHECKING: + from layerlens.instrument.adapters._base.adapter import BaseAdapter + +logger = logging.getLogger(__name__) + + +StateT = TypeVar("StateT") +ToolFunc = Callable[..., Any] + + +@dataclass +class ToolExecution: + """Tracks a single tool execution.""" + + tool_name: str + start_time_ns: int + end_time_ns: int | None = None + input_args: dict[str, Any] | None = None + output: Any | None = None + error: str | None = None + + +class ToolTracer: + """ + Tracer for LangGraph tool executions. + + Emits tool.call (L5a) events for each tool invocation. + + Usage: + tracer = ToolTracer(stratix_instance) + + @tracer.trace + def my_tool(query: str) -> str: + return search(query) + """ + + def __init__( + self, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + ) -> None: + """ + Initialize the tool tracer. 
+ + Args: + stratix_instance: STRATIX SDK instance (legacy) + adapter: BaseAdapter instance (new-style). When provided, + typed event emission is used. + """ + self._stratix = stratix_instance + self._adapter = adapter + self._executions: list[ToolExecution] = [] + + def trace(self, func: ToolFunc) -> ToolFunc: + """ + Decorate a tool function with tracing. + + Emits tool.call event capturing input/output. + + Args: + func: Tool function + + Returns: + Decorated function + """ + tool_name = func.__name__ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + execution = ToolExecution( + tool_name=tool_name, + start_time_ns=time.time_ns(), + input_args=self._capture_input(args, kwargs), + ) + self._executions.append(execution) + + try: + result = func(*args, **kwargs) + execution.end_time_ns = time.time_ns() + execution.output = self._safe_output(result) + self._emit_tool_call(execution) + return result + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_tool_call(execution) + raise + + return wrapper + + def trace_async(self, func: ToolFunc) -> ToolFunc: + """ + Decorate an async tool function with tracing. 
+ + Args: + func: Async tool function + + Returns: + Decorated async function + """ + tool_name = func.__name__ + + @wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + execution = ToolExecution( + tool_name=tool_name, + start_time_ns=time.time_ns(), + input_args=self._capture_input(args, kwargs), + ) + self._executions.append(execution) + + try: + result = await func(*args, **kwargs) + execution.end_time_ns = time.time_ns() + execution.output = self._safe_output(result) + self._emit_tool_call(execution) + return result + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._emit_tool_call(execution) + raise + + return wrapper + + def _capture_input(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: + """Capture tool input arguments.""" + return { + "args": [self._safe_serialize(a) for a in args], + "kwargs": {k: self._safe_serialize(v) for k, v in kwargs.items()}, + } + + def _safe_serialize(self, value: Any) -> Any: + """Safely serialize a value.""" + try: + if isinstance(value, (str, int, float, bool, type(None))): + return value + elif isinstance(value, (list, tuple)): + return [self._safe_serialize(v) for v in value] + elif isinstance(value, dict): + return {k: self._safe_serialize(v) for k, v in value.items()} + else: + return str(value) + except Exception: + return "" + + def _safe_output(self, value: Any) -> Any: + """Safely capture output value.""" + return self._safe_serialize(value) + + def _emit_tool_call(self, execution: ToolExecution) -> None: + """Emit tool.call event via adapter (preferred) or legacy path.""" + duration_ns = (execution.end_time_ns or 0) - execution.start_time_ns + payload_dict = { + "tool_name": execution.tool_name, + "input": execution.input_args, + "output": execution.output, + "duration_ns": duration_ns, + "error": execution.error, + } + + # New-style: route through adapter.emit_event + if self._adapter is not None: + try: + from 
layerlens.instrument._vendored.events import ( + ToolCallEvent, + IntegrationType, + ) + + typed_payload = ToolCallEvent.create( # type: ignore[call-arg,unused-ignore] + tool_name=execution.tool_name, + integration_type=IntegrationType.LIBRARY, + input_data=execution.input_args or {}, + output_data=execution.output, + duration_ns=duration_ns, + error=execution.error, + ) + self._adapter.emit_event(typed_payload) + return + except Exception: + logger.debug("Typed event emission failed, falling back to legacy", exc_info=True) + + # Legacy fallback + if self._stratix and hasattr(self._stratix, "emit"): + self._stratix.emit("tool.call", payload_dict) + + +def trace_langgraph_tool( + func: ToolFunc | None = None, + *, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + tool_name: str | None = None, +) -> ToolFunc | Callable[[ToolFunc], ToolFunc]: + """ + Decorator for tracing LangGraph tool functions. + + Can be used with or without arguments: + + @trace_langgraph_tool + def my_tool(query: str) -> str: + ... + + @trace_langgraph_tool(stratix_instance=stratix) + def my_tool(query: str) -> str: + ... 
+ + Args: + func: Tool function (when used without arguments) + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + tool_name: Custom name for the tool + + Returns: + Decorated function or decorator + """ + tracer = ToolTracer(stratix_instance, adapter=adapter) + + def decorator(f: ToolFunc) -> ToolFunc: + name = tool_name or f.__name__ + + @wraps(f) + def wrapper(*args: Any, **kwargs: Any) -> Any: + execution = ToolExecution( + tool_name=name, + start_time_ns=time.time_ns(), + input_args=tracer._capture_input(args, kwargs), + ) + + try: + result = f(*args, **kwargs) + execution.end_time_ns = time.time_ns() + execution.output = tracer._safe_output(result) + tracer._emit_tool_call(execution) + return result + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + tracer._emit_tool_call(execution) + raise + + return wrapper + + if func is not None: + # Called without arguments: @trace_langgraph_tool + return decorator(func) + else: + # Called with arguments: @trace_langgraph_tool(...) + return decorator + + +class LangGraphToolNode: + """ + Wrapper for creating traced LangGraph tool nodes. + + This creates a node that wraps a tool function and automatically + emits tool.call events. + + Usage: + # Create a traced tool node + search_node = LangGraphToolNode( + tool_func=search_function, + stratix_instance=stratix, + ) + + # Use in graph + graph.add_node("search", search_node) + """ + + def __init__( + self, + tool_func: ToolFunc, + stratix_instance: Any = None, + adapter: BaseAdapter | None = None, + tool_name: str | None = None, + state_key: str | None = None, + ) -> None: + """ + Initialize the tool node. 
+ + Args: + tool_func: The tool function to wrap + stratix_instance: STRATIX SDK instance + adapter: BaseAdapter instance (new-style) + tool_name: Name for the tool (defaults to function name) + state_key: Key in state to use as tool input (if None, uses full state) + """ + self._tool_func = tool_func + self._stratix = stratix_instance + self._tool_name = tool_name or tool_func.__name__ + self._state_key = state_key + self._tracer = ToolTracer(stratix_instance, adapter=adapter) + + def __call__(self, state: dict[str, Any]) -> dict[str, Any]: + """ + Execute the tool node. + + Args: + state: LangGraph state + + Returns: + Updated state + """ + # Get input from state + tool_input = state.get(self._state_key) if self._state_key else state + + execution = ToolExecution( + tool_name=self._tool_name, + start_time_ns=time.time_ns(), + input_args={"state_input": self._tracer._safe_serialize(tool_input)}, + ) + + try: + # Call the tool + result = self._tool_func(tool_input) + execution.end_time_ns = time.time_ns() + execution.output = self._tracer._safe_output(result) + self._tracer._emit_tool_call(execution) + + # Return updated state + return {"tool_output": result} + + except Exception as e: + execution.end_time_ns = time.time_ns() + execution.error = str(e) + self._tracer._emit_tool_call(execution) + raise diff --git a/tests/instrument/adapters/__init__.py b/tests/instrument/adapters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/instrument/adapters/frameworks/__init__.py b/tests/instrument/adapters/frameworks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/instrument/adapters/frameworks/test_autogen_adapter.py b/tests/instrument/adapters/frameworks/test_autogen_adapter.py new file mode 100644 index 00000000..a3cde4cb --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_autogen_adapter.py @@ -0,0 +1,229 @@ +"""Unit tests for the AutoGen framework adapter. 
+ +Mocked at the SDK shape level — no real ``autogen`` runtime needed. +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.frameworks.autogen import ( + ADAPTER_CLASS, + AutoGenAdapter, + instrument_agents, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +class _FakeAgent: + """Minimal duck-typed AutoGen ConversableAgent for tests.""" + + def __init__( + self, + name: str = "agent", + system_message: Any = None, + llm_config: Any = None, + ) -> None: + self.name = name + self.system_message = system_message + self.llm_config = llm_config + + def send(self, message: Any, recipient: Any, **kwargs: Any) -> Any: + return None + + def receive(self, message: Any, sender: Any, **kwargs: Any) -> Any: + return None + + def generate_reply(self, messages: Any = None, sender: Any = None, **kwargs: Any) -> Any: + return "reply" + + def execute_code_blocks(self, code_blocks: Any) -> Any: + return "exec result" + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is AutoGenAdapter + + +def test_lifecycle() -> None: + a = AutoGenAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + a.disconnect() + assert a.status == AdapterStatus.DISCONNECTED + + +def test_adapter_info_and_health() -> None: + a = AutoGenAdapter() + a.connect() + info = a.get_adapter_info() + assert info.framework == "autogen" + assert info.name == "AutoGenAdapter" + health = a.health_check() + assert health.framework_name == "autogen" + + +def test_connect_agents_wraps_methods_and_emits_config() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, 
capture_config=CaptureConfig.full()) + adapter.connect() + + agent = _FakeAgent(name="alice", llm_config={"model": "gpt-5"}) + adapter.connect_agents(agent) + + # Methods replaced. + assert agent.send.__name__ == "traced_send" + assert agent.receive.__name__ == "traced_receive" + assert agent.generate_reply.__name__ == "traced_generate_reply" + assert agent.execute_code_blocks.__name__ == "traced_execute_code" + + # environment.config emitted once for this agent. + configs = [e for e in stratix.events if e["event_type"] == "environment.config"] + assert len(configs) == 1 + + adapter.disconnect() + # Original methods restored. + assert agent.send.__name__ == "send" + + +def test_connect_agents_idempotent() -> None: + """Calling connect_agents twice with the same agent does not double-wrap.""" + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + agent = _FakeAgent(name="alice") + adapter.connect_agents(agent) + adapter.connect_agents(agent) + + configs = [e for e in stratix.events if e["event_type"] == "environment.config"] + assert len(configs) == 1 + + +def test_on_send_emits_handoff() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + sender = _FakeAgent(name="alice") + recipient = _FakeAgent(name="bob") + adapter.on_send(sender=sender, message="hi there", recipient=recipient) + + evt = next(e for e in stratix.events if e["event_type"] == "agent.handoff") + assert evt["payload"]["from_agent"] == "alice" + assert evt["payload"]["to_agent"] == "bob" + assert evt["payload"]["message_seq"] == 1 + + +def test_on_receive_emits_state_change() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + receiver = _FakeAgent(name="bob") + sender = _FakeAgent(name="alice") + adapter.on_receive(receiver=receiver, 
message={"content": "hello"}, sender=sender) + + evt = next(e for e in stratix.events if e["event_type"] == "agent.state.change") + assert evt["payload"]["agent"] == "bob" + assert evt["payload"]["from_agent"] == "alice" + + +def test_on_generate_reply_emits_model_invoke() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + agent = _FakeAgent(name="alice", llm_config={"model": "gpt-5"}) + reply = type("Reply", (), {"usage": {"prompt_tokens": 10, "completion_tokens": 5}})() + adapter.on_generate_reply(agent=agent, messages=[{"role": "user", "content": "hi"}], reply=reply, latency_ms=42.0) + + evt = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert evt["payload"]["agent"] == "alice" + assert evt["payload"]["model"] == "gpt-5" + assert evt["payload"]["latency_ms"] == 42.0 + assert evt["payload"]["tokens_prompt"] == 10 + + +def test_on_execute_code_emits_tool_call_and_environment() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + agent = _FakeAgent(name="alice") + adapter.on_execute_code(agent=agent, code_blocks=[("python", "print(1)")], result="1\n", latency_ms=5.0) + + types = [e["event_type"] for e in stratix.events] + assert "tool.call" in types + assert "tool.environment" in types + tool = next(e for e in stratix.events if e["event_type"] == "tool.call") + assert tool["payload"]["tool_name"] == "code_execution" + assert tool["payload"]["code_blocks_count"] == 1 + + +def test_on_conversation_start_end_emits_input_output() -> None: + stratix = _RecordingStratix() + adapter = AutoGenAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + initiator = _FakeAgent(name="alice") + adapter.on_conversation_start(initiator=initiator, message="start") + adapter.on_conversation_end(final_message="bye", 
termination_reason="max_rounds") + + types = [e["event_type"] for e in stratix.events] + assert "agent.input" in types + assert "agent.output" in types + + out = next(e for e in stratix.events if e["event_type"] == "agent.output") + assert out["payload"]["termination_reason"] == "max_rounds" + assert out["payload"]["duration_ns"] >= 0 + + +def test_capture_config_gates_l3_model_metadata() -> None: + stratix = _RecordingStratix() + cfg = CaptureConfig(l3_model_metadata=False) + adapter = AutoGenAdapter(stratix=stratix, capture_config=cfg) + adapter.connect() + + agent = _FakeAgent(name="alice", llm_config={"model": "gpt-5"}) + adapter.on_generate_reply(agent=agent, reply="hi") + sender = _FakeAgent(name="alice") + recipient = _FakeAgent(name="bob") + adapter.on_send(sender=sender, message="x", recipient=recipient) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" not in types + # handoff (from on_send) is cross-cutting / always enabled. + assert "agent.handoff" in types + + +def test_instrument_agents_helper() -> None: + """Top-level convenience wraps multiple agents at once.""" + a = _FakeAgent(name="a") + b = _FakeAgent(name="b") + result = instrument_agents(a, b) + assert isinstance(result, list) + assert len(result) == 2 + # Both wrapped. 
+ assert a.send.__name__ == "traced_send" + assert b.send.__name__ == "traced_send" + + +def test_serialize_for_replay() -> None: + adapter = AutoGenAdapter( + stratix=_RecordingStratix(), + capture_config=CaptureConfig.full(), + ) + adapter.connect() + rt = adapter.serialize_for_replay() + assert rt.framework == "autogen" + assert rt.adapter_name == "AutoGenAdapter" + assert "capture_config" in rt.config diff --git a/tests/instrument/adapters/frameworks/test_crewai_adapter.py b/tests/instrument/adapters/frameworks/test_crewai_adapter.py new file mode 100644 index 00000000..7f236694 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_crewai_adapter.py @@ -0,0 +1,204 @@ +"""Unit tests for the CrewAI framework adapter. + +Mocked at the SDK shape level — no real ``crewai`` runtime needed. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.frameworks.crewai import ( + ADAPTER_CLASS, + CrewAIAdapter, + LayerLensCrewCallback, + instrument_crew, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +class _FakeCrew: + def __init__(self, agents: Any = None, process: Any = None) -> None: + self.agents = agents or [] + self.process = process + self.step_callback: Any = None + self.task_callback: Any = None + + +def _make_agent(role: str = "researcher", tools: Any = None, llm: Any = None) -> SimpleNamespace: + return SimpleNamespace( + role=role, + goal="goal", + backstory="back", + verbose=False, + allow_delegation=False, + max_iter=5, + memory=False, + tools=tools, + llm=llm, + ) + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is CrewAIAdapter + 
+ +def test_lifecycle() -> None: + a = CrewAIAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + a.disconnect() + assert a.status == AdapterStatus.DISCONNECTED + + +def test_adapter_info_and_health() -> None: + a = CrewAIAdapter() + a.connect() + info = a.get_adapter_info() + assert info.framework == "crewai" + assert info.name == "CrewAIAdapter" + health = a.health_check() + assert health.framework_name == "crewai" + + +def test_instrument_crew_attaches_callback_and_emits_config() -> None: + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + crew = _FakeCrew( + agents=[_make_agent(role="researcher"), _make_agent(role="writer")], + process="sequential", + ) + instrumented = adapter.instrument_crew(crew) + + # Callbacks attached. + assert instrumented.step_callback is not None + assert instrumented.task_callback is not None + assert isinstance(instrumented._stratix_callback, LayerLensCrewCallback) + + # Two environment.config events — one per agent role. 
+ configs = [e for e in stratix.events if e["event_type"] == "environment.config"] + assert len(configs) == 2 + roles = {c["payload"]["agent_role"] for c in configs} + assert roles == {"researcher", "writer"} + + +def test_environment_config_idempotent_per_role() -> None: + """Re-instrumenting a crew with same agents should not re-emit configs.""" + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + crew = _FakeCrew(agents=[_make_agent(role="researcher")]) + adapter.instrument_crew(crew) + adapter.instrument_crew(crew) + + configs = [e for e in stratix.events if e["event_type"] == "environment.config"] + assert len(configs) == 1 + + +def test_on_crew_start_end_emits_input_output() -> None: + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + adapter.on_crew_start(crew_input="research topic") + adapter.on_crew_end(crew_output="report") + + types = [e["event_type"] for e in stratix.events] + assert "agent.input" in types + assert "agent.output" in types + + out = next(e for e in stratix.events if e["event_type"] == "agent.output") + assert out["payload"]["output"] == "report" + assert out["payload"]["duration_ns"] >= 0 + + +def test_on_task_start_end_emits_code_and_state_change() -> None: + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + adapter.on_task_start("research", agent_role="researcher", task_order=1) + + # Build a task_output with token_usage to also verify cost.record fires. 
+ task_output = SimpleNamespace( + token_usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + ) + adapter.on_task_end(task_output=task_output, agent_role="researcher", task_order=1) + + types = [e["event_type"] for e in stratix.events] + assert "agent.code" in types + assert "agent.state.change" in types + assert "cost.record" in types + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + assert cost["payload"]["tokens_total"] == 15 + + +def test_on_tool_use_emits_event() -> None: + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + adapter.on_tool_use("calc", tool_input={"x": 1}, tool_output=2, latency_ms=12.3) + + evt = next(e for e in stratix.events if e["event_type"] == "tool.call") + assert evt["payload"]["tool_name"] == "calc" + assert evt["payload"]["latency_ms"] == 12.3 + + +def test_on_delegation_emits_handoff() -> None: + stratix = _RecordingStratix() + adapter = CrewAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + adapter.on_delegation(from_agent="researcher", to_agent="writer", context="findings") + + evt = next(e for e in stratix.events if e["event_type"] == "agent.handoff") + assert evt["payload"]["from_agent"] == "researcher" + assert evt["payload"]["to_agent"] == "writer" + + +def test_capture_config_gates_l5a_tool_calls() -> None: + stratix = _RecordingStratix() + cfg = CaptureConfig(l5a_tool_calls=False) + adapter = CrewAIAdapter(stratix=stratix, capture_config=cfg) + adapter.connect() + + adapter.on_tool_use("calc", tool_input={"x": 1}, tool_output=2) + adapter.on_delegation(from_agent="a", to_agent="b", context="x") + + types = [e["event_type"] for e in stratix.events] + assert "tool.call" not in types + # handoff is cross-cutting / always enabled. 
+ assert "agent.handoff" in types + + +def test_instrument_crew_helper() -> None: + """Top-level convenience function returns the instrumented crew.""" + crew = _FakeCrew(agents=[_make_agent(role="r1")]) + result = instrument_crew(crew) + # The helper returns the crew itself (with callbacks attached). + assert result is crew + assert result._stratix_callback is not None + + +def test_serialize_for_replay() -> None: + adapter = CrewAIAdapter( + stratix=_RecordingStratix(), + capture_config=CaptureConfig.full(), + ) + adapter.connect() + rt = adapter.serialize_for_replay() + assert rt.framework == "crewai" + assert rt.adapter_name == "CrewAIAdapter" + assert "capture_config" in rt.config diff --git a/tests/instrument/adapters/frameworks/test_langchain_capabilities.py b/tests/instrument/adapters/frameworks/test_langchain_capabilities.py new file mode 100644 index 00000000..48082560 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langchain_capabilities.py @@ -0,0 +1,45 @@ +"""Capability regression tests for the LangChain framework adapter. + +PR #119 (brand leak + capability declarations) wired REPLAY into the +adapters that lived on its branch but deferred LangChain because it +lives on the orchestration source-port branch (PR #96). This test +file is the regression guard for the closure: REPLAY must be declared +because ``LayerLensCallbackHandler.serialize_for_replay`` returns a +non-stub ``ReplayableTrace``, and STREAMING must NOT be declared +because the adapter does not register an ``on_llm_new_token`` callback +(no per-chunk events flow through the adapter — see callbacks.py). + +Per CLAUDE.md 'no fake claims', a capability is only declared if the +adapter actually implements it. 
+""" + +from __future__ import annotations + +from layerlens.instrument.adapters._base.adapter import AdapterCapability +from layerlens.instrument.adapters.frameworks.langchain.callbacks import ( + LayerLensCallbackHandler, +) + + +def test_declares_replay_capability() -> None: + handler = LayerLensCallbackHandler() + caps = handler.info().capabilities + assert AdapterCapability.REPLAY in caps + + +def test_does_not_declare_streaming_capability() -> None: + """LangChain adapter has no ``on_llm_new_token`` callback — no + per-chunk events flow through it. STREAMING must stay undeclared + until ``on_llm_new_token`` is wired (and tested) explicitly.""" + handler = LayerLensCallbackHandler() + caps = handler.info().capabilities + assert AdapterCapability.STREAMING not in caps + + +def test_get_adapter_info_matches_info_wrapper() -> None: + """``info()`` (BaseAdapter wrapper) and ``get_adapter_info()`` + (subclass override) must report identical capability lists so the + catalog manifest emitter sees the same answer regardless of which + entrypoint it calls.""" + handler = LayerLensCallbackHandler() + assert handler.info().capabilities == handler.get_adapter_info().capabilities diff --git a/tests/instrument/adapters/frameworks/test_langfuse_adapter.py b/tests/instrument/adapters/frameworks/test_langfuse_adapter.py new file mode 100644 index 00000000..0fc633a1 --- /dev/null +++ b/tests/instrument/adapters/frameworks/test_langfuse_adapter.py @@ -0,0 +1,162 @@ +"""Unit tests for the Langfuse framework adapter. + +Mocked at the SDK shape level — no real Langfuse API calls. + +Unlike runtime-wrapping adapters, the Langfuse adapter is a data +import/export pipeline. 
Tests focus on: + + * lifecycle (with and without config) + * connect-without-config edge case (adapter still healthy, no client) + * import/export return SyncResult with appropriate error message when + no client is configured + * SyncState tracking semantics + * health_check / get_status structural correctness + * serialize_for_replay returns proper ReplayableTrace +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.frameworks.langfuse import ( + ADAPTER_CLASS, + LangfuseAdapter, +) +from layerlens.instrument.adapters.frameworks.langfuse.config import ( + SyncDirection, + LangfuseConfig, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is LangfuseAdapter + + +def test_lifecycle_no_config() -> None: + """Adapter is usable without a Langfuse config — connects HEALTHY but no client.""" + a = LangfuseAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + assert a.is_connected is True + a.disconnect() + assert a.status == AdapterStatus.DISCONNECTED + + +def test_adapter_info_and_health() -> None: + a = LangfuseAdapter() + a.connect() + info = a.get_adapter_info() + assert info.framework == "langfuse" + assert info.name == "LangfuseAdapter" + health = a.health_check() + assert health.framework_name == "langfuse" + assert "No Langfuse config" in (health.message or "") + + +def test_import_returns_error_result_when_not_connected() -> None: + """Without a Langfuse client, import_traces returns an errored SyncResult.""" + a = LangfuseAdapter() + a.connect() + result = a.import_traces() + assert result.direction == 
SyncDirection.IMPORT + assert result.errors + assert "not connected" in result.errors[0].lower() + + +def test_export_returns_error_result_when_not_connected() -> None: + """Without a Langfuse client, export_traces returns an errored SyncResult.""" + a = LangfuseAdapter() + a.connect() + result = a.export_traces(events_by_trace={"trace-1": []}) + assert result.direction == SyncDirection.EXPORT + assert result.errors + assert "not connected" in result.errors[0].lower() + + +def test_sync_returns_error_result_when_not_connected() -> None: + """Without a Langfuse client, sync() returns an errored SyncResult.""" + a = LangfuseAdapter() + a.connect() + result = a.sync() + assert result.errors + assert "not connected" in result.errors[0].lower() + + +def test_sync_state_tracking() -> None: + """SyncState records imports/exports and updates cursors.""" + a = LangfuseAdapter() + a.connect() + state = a.sync_state + + from datetime import datetime, timezone + + UTC = timezone.utc # Python 3.11+ has datetime.UTC; alias for 3.9/3.10 compat. 
+ + t0 = datetime(2024, 1, 1, tzinfo=UTC) + t1 = datetime(2024, 1, 2, tzinfo=UTC) + state.record_import("trace-1", t0) + state.record_import("trace-2", t1) + assert "trace-1" in state.imported_trace_ids + assert "trace-2" in state.imported_trace_ids + assert state.last_import_cursor == t1 + + +def test_get_status_structure() -> None: + """get_status returns a complete status dict.""" + a = LangfuseAdapter() + a.connect() + status = a.get_status() + assert "connected" in status + assert "langfuse_healthy" in status + assert "host" in status + assert "imported_traces" in status + assert "exported_traces" in status + assert "quarantined_traces" in status + + +def test_config_property_default_none() -> None: + """When no config is provided, the config property is None.""" + a = LangfuseAdapter() + a.connect() + assert a.config is None + + +def test_config_property_returns_provided_config() -> None: + """When a config is passed to the constructor, it is exposed via the config property. + + This avoids ``connect(config=cfg)``, which would attempt an HTTP health + check against the (fake) host. The constructor path stores the config + without networking, so the property can be checked without calling + ``connect()`` at all; this test skips that call. + """ + cfg = LangfuseConfig(public_key="pk-test", secret_key="sk-test", host="https://api/") + a = LangfuseAdapter(config=cfg) + assert a.config is cfg + # Trailing slash is stripped by the validator. 
+ assert a.config.host == "https://api" + + +def test_capture_config_passes_through() -> None: + """The standard capture_config kwarg is accepted and stored.""" + cfg = CaptureConfig.full() + a = LangfuseAdapter(capture_config=cfg) + assert a.capture_config is cfg + + +def test_serialize_for_replay() -> None: + a = LangfuseAdapter(stratix=_RecordingStratix(), capture_config=CaptureConfig.full()) + a.connect() + rt = a.serialize_for_replay() + assert rt.framework == "langfuse" + assert rt.adapter_name == "LangfuseAdapter" + assert "sync_state" in rt.metadata
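The test modules above share one pattern: a recording stub stands in for the STRATIX instance, and each test filters the captured events by `event_type`. The following is a minimal, self-contained sketch of that pattern. `RecordingStratix` mirrors the `_RecordingStratix` stub from the tests, while `ToyAdapter` and `events_of_type` are hypothetical names introduced only for illustration and are not part of layerlens.

```python
from typing import Any, Dict, List


class RecordingStratix:
    """Test double that records every emit(event_type, payload) call."""

    def __init__(self) -> None:
        self.events: List[Dict[str, Any]] = []

    def emit(self, *args: Any, **kwargs: Any) -> None:
        # Same guard as the stubs in the test modules: any other call
        # shape is silently ignored rather than recorded.
        if len(args) == 2 and isinstance(args[0], str):
            self.events.append({"event_type": args[0], "payload": args[1]})


class ToyAdapter:
    """Hypothetical adapter (an assumption, not layerlens API)."""

    def __init__(self, stratix: RecordingStratix) -> None:
        self._stratix = stratix
        self._seq = 0

    def on_send(self, sender: str, recipient: str) -> None:
        # Number each message so tests can assert on ordering.
        self._seq += 1
        self._stratix.emit(
            "agent.handoff",
            {"from_agent": sender, "to_agent": recipient, "message_seq": self._seq},
        )


def events_of_type(stratix: RecordingStratix, event_type: str) -> List[Dict[str, Any]]:
    """Return the recorded events whose type matches event_type."""
    return [e for e in stratix.events if e["event_type"] == event_type]


stratix = RecordingStratix()
adapter = ToyAdapter(stratix)
adapter.on_send("alice", "bob")
stratix.emit("malformed")  # one positional argument, so it is dropped

handoffs = events_of_type(stratix, "agent.handoff")
```

Because the stub drops calls that do not match the two-argument shape, a test built this way only sees events emitted in that form; an adapter that emits through a different call signature would need a different stub.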