diff --git a/docs/adapters/providers-anthropic.md b/docs/adapters/providers-anthropic.md new file mode 100644 index 0000000..2ba0b8d --- /dev/null +++ b/docs/adapters/providers-anthropic.md @@ -0,0 +1,70 @@ +# Anthropic provider adapter + +`layerlens.instrument.adapters.providers.anthropic_adapter.AnthropicAdapter` +instruments the Anthropic Python SDK to emit telemetry on every +`messages.create` and `messages.stream` call. + +## Install + +```bash +pip install 'layerlens[providers-anthropic]' +``` + +Pulls `anthropic>=0.30,<1`. + +## Quick start + +```python +from anthropic import Anthropic +from layerlens.instrument.adapters.providers.anthropic_adapter import AnthropicAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="anthropic") +adapter = AnthropicAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = Anthropic() +adapter.connect_client(client) + +response = client.messages.create( + model="claude-haiku-4-5-20251001", + max_tokens=20, + messages=[{"role": "user", "content": "Hello"}], +) + +adapter.disconnect() +sink.close() +``` + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `model.invoke` | L3 | Every `messages.create` (success or failure) and once per stream completion | +| `cost.record` | cross-cutting | Every successful call with token usage | +| `tool.call` | L5a | One per `tool_use` block in the response | +| `policy.violation` | cross-cutting | When the SDK raises (rate limit, invalid input, etc.) | + +The `model.invoke` payload includes Anthropic-specific fields: +- `cache_creation_input_tokens` / `cache_read_input_tokens` (when prompt caching is used) +- `parameters.has_system: true` when a system prompt is supplied +- `parameters.tools_count` when tools are passed +- `reasoning_tokens` (Claude extended thinking) + +## Streaming + +The adapter wraps both `messages.create(stream=True)` and the +`messages.stream()` context manager. 
A single consolidated `model.invoke` +fires on stream completion, accumulating content from `text_delta` events +and tool input from `input_json_delta` events. + +## Cost calculation + +Pricing comes from the canonical table — Claude models get the 90% cached-token +discount automatically. + +## BYOK + +Same pattern as the OpenAI adapter — pass `api_key` to the `Anthropic()` client. +The platform-side BYOK store ships in atlas-app M1.B. diff --git a/docs/adapters/providers-azure-openai.md b/docs/adapters/providers-azure-openai.md new file mode 100644 index 0000000..e816281 --- /dev/null +++ b/docs/adapters/providers-azure-openai.md @@ -0,0 +1,48 @@ +# Azure OpenAI provider adapter + +`layerlens.instrument.adapters.providers.azure_openai_adapter.AzureOpenAIAdapter` +uses the same `openai` SDK as the OpenAI adapter but captures Azure-specific +metadata (deployment, endpoint, region, api-version) and uses the Azure +pricing table. + +## Install + +```bash +pip install 'layerlens[providers-azure-openai]' +``` + +## Quick start + +```python +from openai import AzureOpenAI +from layerlens.instrument.adapters.providers.azure_openai_adapter import AzureOpenAIAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="azure_openai") +adapter = AzureOpenAIAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = AzureOpenAI( + api_key="...", + api_version="2024-08-01-preview", + azure_endpoint="https://my-resource.openai.azure.com/", +) +adapter.connect_client(client) + +client.chat.completions.create(model="my-deployment", messages=[...]) +``` + +## Azure-specific behavior + +- **Endpoint sanitization**: query strings are stripped from the captured + `azure_endpoint` to prevent token leakage if the URL ever contains an + `api-key` query param. +- **Pricing**: cost calculations use `AZURE_PRICING` (different rates than + OpenAI's public API). 
+- **api-version**: read from `client._api_version` or the `api-version` key of + `client._custom_query` and surfaced in every `model.invoke`. + +## Events emitted + +Same set as OpenAI: `model.invoke`, `cost.record`, `tool.call`, `policy.violation`. diff --git a/docs/adapters/providers-bedrock.md b/docs/adapters/providers-bedrock.md new file mode 100644 index 0000000..5982f10 --- /dev/null +++ b/docs/adapters/providers-bedrock.md @@ -0,0 +1,64 @@ +# AWS Bedrock provider adapter + +`layerlens.instrument.adapters.providers.bedrock_adapter.AWSBedrockAdapter` +wraps the `bedrock-runtime` boto3 client. Bedrock is a multi-provider +front: Anthropic, Meta, Cohere, Amazon Titan, AI21, and Mistral models all +flow through the same client interface but with different request and +response body shapes. The adapter detects the provider family from +`modelId` and parses tokens, content, and stop reasons accordingly. + +## Install + +```bash +pip install 'layerlens[providers-bedrock]' +``` + +Pulls `boto3>=1.34`. + +## Quick start + +```python +import boto3 +from layerlens.instrument.adapters.providers.bedrock_adapter import AWSBedrockAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="aws_bedrock") +adapter = AWSBedrockAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = boto3.client("bedrock-runtime", region_name="us-east-1") +adapter.connect_client(client) + +# Either invoke_model or converse — both wrapped. +client.converse( + modelId="anthropic.claude-3-5-sonnet-20241022-v2:0", + messages=[{"role": "user", "content": [{"text": "Hi"}]}], +) +``` + +## Wrapped methods + +- `invoke_model` — body is JSON, parsed per provider family. Response body is + wrapped in a `_RereadableBody` so the caller's downstream `body.read()` + still works. +- `converse` — unified Anthropic-style envelope. Token extraction is uniform. 
+- `invoke_model_with_response_stream` — emits `model.invoke` immediately with + `streaming=true`; content extraction during stream consumption is deferred + to a future PR. +- `converse_stream` — same. + +## Provider-family token extraction + +| `modelId` prefix | Family | Token fields | +|---|---|---| +| `anthropic.` | anthropic | `usage.input_tokens` / `usage.output_tokens` | +| `meta.` | meta | `prompt_token_count` / `generation_token_count` | +| `cohere.` | cohere | `meta.billed_units.input_tokens` / `output_tokens` | +| `amazon.` | amazon | (no usage in body; tokens come from `Converse` API) | +| `ai21.` | ai21 | (handled via `Converse` API) | +| `mistral.` | mistral | `prompt_tokens` / `completion_tokens` | + +## Cost calculation + +Uses the `BEDROCK_PRICING` table (separate from OpenAI/Azure tables). diff --git a/docs/adapters/providers-cohere.md b/docs/adapters/providers-cohere.md new file mode 100644 index 0000000..ca7cfda --- /dev/null +++ b/docs/adapters/providers-cohere.md @@ -0,0 +1,78 @@ +# Cohere provider adapter + +`layerlens.instrument.adapters.providers.cohere_adapter.CohereAdapter` +instruments the Cohere Python SDK (v5+) for chat (v1 + v2) and embeddings. + +## Install + +```bash +pip install 'layerlens[providers-cohere]' +``` + +Pulls `cohere>=5.0,<6`. 
+ +## Quick start + +```python +import cohere +from layerlens.instrument.adapters.providers.cohere_adapter import CohereAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="cohere") +adapter = CohereAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = cohere.Client() +adapter.connect_client(client) + +# v1 chat (single message + optional preamble) +response = client.chat(model="command-r-plus", message="Hello", preamble="Be concise.") + +# v2 chat (OpenAI-style messages list) +response = client.v2.chat( + model="command-r-plus", + messages=[{"role": "user", "content": "Hello"}], +) +``` + +## What's wrapped + +- `client.chat` (v1) — `message` is normalized to a `user` role; optional `preamble` becomes a `system` message at index 0. +- `client.v2.chat` — already OpenAI-style; messages pass through. +- `client.embed` — `meta.billed_units.input_tokens` populates the cost record. + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `model.invoke` | L3 | Every chat or embed call (success or failure). | +| `cost.record` | cross-cutting | Every successful call with billed units. | +| `tool.call` | L5a | One per tool call in the response (v1: `tool_calls[].name/parameters`; v2: `message.tool_calls[].function.{name,arguments}`). | +| `policy.violation` | cross-cutting | When the SDK raises (rate limit, invalid input, etc.). | + +## Cost calculation + +Pricing is sourced from the canonical `PRICING` table: + +| Model | Input | Output | +|---|---|---| +| command-r-plus | $0.003 | $0.015 | +| command-r | $0.0005 | $0.0015 | +| command-r-plus-08-2024 | $0.0025 | $0.01 | +| command-r-08-2024 | $0.00015 | $0.0006 | +| command | $0.001 | $0.002 | +| command-light | $0.0003 | $0.0006 | + +Cohere-via-Bedrock models use `BEDROCK_PRICING` instead. + +## Streaming + +The current adapter wraps only the non-streaming `chat`, `v2.chat`, and `embed` +calls. 
If you call `client.chat_stream(...)` directly, the underlying +function is not currently wrapped — open an issue if you need it. + +## BYOK + +Pass `api_key` to `cohere.Client(api_key=...)` as you would normally. +The platform-side BYOK store ships in atlas-app M1.B. diff --git a/docs/adapters/providers-google-vertex.md b/docs/adapters/providers-google-vertex.md new file mode 100644 index 0000000..8fd986c --- /dev/null +++ b/docs/adapters/providers-google-vertex.md @@ -0,0 +1,52 @@ +# Google Vertex AI provider adapter + +`layerlens.instrument.adapters.providers.google_vertex_adapter.GoogleVertexAdapter` +wraps `GenerativeModel.generate_content` from either the +`google.generativeai` or `vertexai.generative_models` SDK. + +## Install + +```bash +pip install 'layerlens[providers-vertex]' +``` + +Pulls `google-cloud-aiplatform>=1.50,<2`. + +## Quick start + +```python +from vertexai.generative_models import GenerativeModel +from layerlens.instrument.adapters.providers.google_vertex_adapter import GoogleVertexAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="google_vertex") +adapter = GoogleVertexAdapter() +adapter.add_sink(sink) +adapter.connect() + +model = GenerativeModel("gemini-1.5-pro") +adapter.connect_client(model) + +response = model.generate_content("Why is the sky blue?") +``` + +## Vertex-specific behavior + +- **`models/` prefix stripping**: `model_name="models/gemini-1.5-pro"` is normalized to + `gemini-1.5-pro` for pricing-table lookup. +- **Function calls**: extracted from `candidates[0].content.parts[].function_call` + and emitted as `tool.call` events with the `args` dict. +- **`thoughts_token_count`**: when the model returns reasoning tokens, they + populate `model.invoke.reasoning_tokens`. +- **`finish_reason`**: enum value name is captured (e.g., `"STOP"`, `"MAX_TOKENS"`). 
+ +## Streaming + +`generate_content(stream=True)` is wrapped — the adapter accumulates +chunk-level usage and emits one consolidated `model.invoke` on stream +completion. Function calls in streamed responses follow the same accumulation +pattern. + +## Cost calculation + +Gemini models get the 75% cached-token discount per the canonical pricing table. diff --git a/docs/adapters/providers-litellm.md b/docs/adapters/providers-litellm.md new file mode 100644 index 0000000..fedb6af --- /dev/null +++ b/docs/adapters/providers-litellm.md @@ -0,0 +1,67 @@ +# LiteLLM provider adapter + +`layerlens.instrument.adapters.providers.litellm_adapter.LiteLLMAdapter` +hooks into LiteLLM's callback system rather than monkey-patching client +methods. This avoids interfering with LiteLLM's own routing, fallback, and +retry behavior. + +## Install + +```bash +pip install 'layerlens[providers-litellm]' +``` + +Pulls `litellm>=1.40,<2`. + +## Quick start + +```python +import litellm +from layerlens.instrument.adapters.providers.litellm_adapter import LiteLLMAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="litellm") +adapter = LiteLLMAdapter() +adapter.add_sink(sink) +adapter.connect() # registers the callback with litellm.callbacks + +# No connect_client needed — the callback is module-global. 
+litellm.completion( + model="openai/gpt-4o-mini", + messages=[{"role": "user", "content": "Hi"}], +) + +adapter.disconnect() # removes the callback +``` + +## Provider auto-detection + +The adapter parses LiteLLM model strings and routes the `provider` field of +each event to the underlying provider name: + +| Prefix | Provider | +|---|---| +| `openai/` | `openai` | +| `anthropic/` | `anthropic` | +| `azure/` | `azure_openai` | +| `bedrock/` | `aws_bedrock` | +| `vertex_ai/` | `google_vertex` | +| `ollama/` | `ollama` | +| `cohere/` | `cohere` | +| `groq/` | `groq` | +| (no prefix) | inferred from model name (`gpt-`, `claude-`, `gemini-`, ...) | + +Unrecognized models get `provider="unknown"`. + +## Cost calculation + +Cost is sourced in this order: +1. LiteLLM's own `litellm.completion_cost(...)` — if it returns a non-None value, + it's used and the event is tagged `cost_source: "litellm"`. +2. The canonical LayerLens pricing table for the resolved provider. + +## Backward-compat alias + +`STRATIXLiteLLMCallback` is preserved as an alias for `LayerLensLiteLLMCallback` +so users coming from the `ateam` framework codebase don't need to rewrite +imports immediately. The alias will be removed in v2.0. diff --git a/docs/adapters/providers-mistral.md b/docs/adapters/providers-mistral.md new file mode 100644 index 0000000..9a6987a --- /dev/null +++ b/docs/adapters/providers-mistral.md @@ -0,0 +1,65 @@ +# Mistral AI provider adapter + +`layerlens.instrument.adapters.providers.mistral_adapter.MistralAdapter` +instruments the `mistralai` v1 SDK for `chat.complete`, `chat.stream`, +and `embeddings.create`. + +## Install + +```bash +pip install 'layerlens[providers-mistral]' +``` + +Pulls `mistralai>=1.0,<2`. 
+ +## Quick start + +```python +from mistralai import Mistral +from layerlens.instrument.adapters.providers.mistral_adapter import MistralAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="mistral") +adapter = MistralAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = Mistral(api_key="...") +adapter.connect_client(client) + +response = client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "Hello"}], +) +``` + +## What's wrapped + +- `client.chat.complete` — synchronous chat (OpenAI-shape response). +- `client.chat.stream` — streaming wrapper accumulates content + tool-call deltas; emits **one** consolidated `model.invoke` on iterator exhaustion. +- `client.embeddings.create` — embedding telemetry. + +## Events emitted + +Same set as OpenAI: `model.invoke`, `cost.record`, `tool.call`, `policy.violation`. + +The streaming path emits a single `model.invoke` with `metadata.streaming=true` +on completion, not per chunk. + +## Cost calculation + +| Model | Input | Output | +|---|---|---| +| mistral-large / mistral-large-latest | $0.002 | $0.006 | +| mistral-small / mistral-small-latest | $0.0002 | $0.0006 | +| mistral-medium | $0.0027 | $0.0081 | +| open-mistral-7b | $0.00025 | $0.00025 | +| open-mixtral-8x7b | $0.0007 | $0.0007 | +| open-mixtral-8x22b | $0.002 | $0.006 | + +Mistral-via-Bedrock uses `BEDROCK_PRICING`. + +## BYOK + +Pass `api_key` to `Mistral(api_key=...)` as normal. The platform-side +BYOK store ships in atlas-app M1.B. diff --git a/docs/adapters/providers-ollama.md b/docs/adapters/providers-ollama.md new file mode 100644 index 0000000..3e905ce --- /dev/null +++ b/docs/adapters/providers-ollama.md @@ -0,0 +1,50 @@ +# Ollama provider adapter + +`layerlens.instrument.adapters.providers.ollama_adapter.OllamaAdapter` +instruments the Ollama Python SDK for local LLM inference. 
+ +## Install + +```bash +pip install 'layerlens[providers-ollama]' +``` + +Pulls `ollama>=0.2`. + +## Quick start + +```python +import ollama +from layerlens.instrument.adapters.providers.ollama_adapter import OllamaAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="ollama") +adapter = OllamaAdapter(cost_per_second=0.005) # optional infra cost +adapter.add_sink(sink) +adapter.connect() + +# Wrap the ollama module — it acts as a global client. +adapter.connect_client(ollama) + +ollama.chat(model="llama3.1", messages=[{"role": "user", "content": "Hi"}]) +``` + +## Ollama-specific behavior + +- **`api_cost_usd: 0.0`** is always emitted because Ollama runs locally — there's + no API to bill for. +- **Optional `infra_cost_usd`**: if you pass `cost_per_second` to the adapter + constructor, the adapter sums `prompt_eval_duration` + `eval_duration` (both + in nanoseconds) and computes `total_seconds * cost_per_second`. Useful for + attributing GPU rental cost to specific LLM calls. +- **Endpoint capture**: `OLLAMA_HOST` env var (or `http://localhost:11434`) is + recorded in every event so you can identify which Ollama instance handled a + request. +- **Three methods wrapped**: `chat`, `generate`, and `embeddings`. The + `method` field in `model.invoke.metadata` distinguishes them. + +## Token extraction + +Ollama responses (dict or object form) expose `prompt_eval_count` and +`eval_count` — these map to `prompt_tokens` and `completion_tokens` in +`NormalizedTokenUsage`. diff --git a/docs/adapters/providers-openai.md b/docs/adapters/providers-openai.md new file mode 100644 index 0000000..a5f429e --- /dev/null +++ b/docs/adapters/providers-openai.md @@ -0,0 +1,126 @@ +# OpenAI provider adapter + +`layerlens.instrument.adapters.providers.openai_adapter.OpenAIAdapter` instruments +the OpenAI Python SDK to emit telemetry on every chat completion, embedding, or +streaming call. 
+ +## Install + +```bash +pip install 'layerlens[providers-openai]' +``` + +Pulls `openai>=1.30,<2`. + +## Quick start + +```python +from openai import OpenAI +from layerlens.instrument.adapters.providers.openai_adapter import OpenAIAdapter +from layerlens.instrument.transport.sink_http import HttpEventSink + +sink = HttpEventSink(adapter_name="openai") # ships to atlas-app +adapter = OpenAIAdapter() +adapter.add_sink(sink) +adapter.connect() + +client = OpenAI() +adapter.connect_client(client) + +# Every call from now on is instrumented. +response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Hello"}], +) + +adapter.disconnect() +sink.close() +``` + +## Events emitted + +| Event | Layer | When | +|---|---|---| +| `model.invoke` | L3 | Every chat completion or embedding call (success or failure). | +| `cost.record` | cross-cutting | Every successful call with token usage in the response. | +| `tool.call` | L5a | One per tool call returned in a chat response. | +| `policy.violation` | cross-cutting | When the OpenAI SDK raises (rate limit, invalid input, etc.). | + +The `model.invoke` payload includes: + +- `provider`, `model`, `parameters` (temperature, max_tokens, top_p, ...) +- `prompt_tokens`, `completion_tokens`, `total_tokens`, + `cached_tokens` (when present), `reasoning_tokens` (o1/o3) +- `latency_ms`, `response_id`, `system_fingerprint`, `service_tier` +- `messages` (input) and `output_message` — captured only when + `CaptureConfig.capture_content` is True (the default). + +## Streaming + +For streaming responses, the adapter wraps the iterator and accumulates +content + tool-call deltas + final usage. A **single** `model.invoke` is emitted +on stream completion with `metadata.streaming=true`, not one per chunk. + +To get token usage for streamed responses, pass +`stream_options={"include_usage": True}` to `client.chat.completions.create`. 
+ +## Capture config + +```python +from layerlens.instrument.adapters._base import CaptureConfig + +# Production-light: only L1 + protocol discovery + lifecycle. +adapter = OpenAIAdapter(capture_config=CaptureConfig.minimal()) + +# Recommended: L1 + L3 + L4a + L5a + L6. +adapter = OpenAIAdapter(capture_config=CaptureConfig.standard()) + +# Everything (development / debugging). +adapter = OpenAIAdapter(capture_config=CaptureConfig.full()) + +# Hand-rolled: redact prompt/response content but keep tokens + costs. +adapter = OpenAIAdapter( + capture_config=CaptureConfig( + l3_model_metadata=True, + capture_content=False, + ), +) +``` + +## Cost calculation + +Costs are computed from the canonical pricing table in +`layerlens.instrument.adapters.providers._base.pricing.PRICING`. The table +hash is matched against `ateam` in CI to prevent drift. + +If a model is not in the table the `cost.record` event still fires with +`api_cost_usd: null` and `pricing_unavailable: true`. + +## BYOK + +The adapter does NOT manage your OpenAI API key. Pass it to the OpenAI client +as you would normally: + +```python +import os +from openai import OpenAI +client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) +``` + +The platform-side BYOK store (atlas-app `byok_credentials` table, encrypted in +AWS Secrets Manager) is for orgs that want their key managed centrally. In +that flow the SDK fetches the key from atlas-app at startup and passes it to +the OpenAI client. See `docs/adapters/byok.md` for setup. + +## Restoring originals + +`adapter.disconnect()` restores the original `client.chat.completions.create` +and `client.embeddings.create` methods. After disconnect, the client behaves +exactly as before `connect_client` was called. + +## Circuit breaker + +If `_stratix.emit()` fails 10 times in a row (transport down, server 5xx +storm), the circuit opens and events are silently dropped for 60 s. After +the cooldown a single attempt is made; success resumes normal flow. 
+This protects the user's program from a flaky telemetry pipeline. diff --git a/pyproject.toml b/pyproject.toml index ae6d1dc..1fda7a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,29 @@ classifiers = [ [project.optional-dependencies] cli = ["click>=8.0.0"] +# --- Instrument layer: LLM provider adapters --- +# Adding any extra below MUST keep the default `pip install layerlens` +# install set unchanged. Verified by `tests/instrument/test_default_install.py`. +providers-openai = ["openai>=1.30,<2"] +providers-anthropic = ["anthropic>=0.30,<1"] +providers-azure-openai = ["openai>=1.30,<2"] +providers-bedrock = ["boto3>=1.34"] +providers-vertex = ["google-cloud-aiplatform>=1.50,<2"] +providers-ollama = ["ollama>=0.2"] +providers-litellm = ["litellm>=1.40,<2"] +providers-cohere = ["cohere>=5.0,<6"] +providers-mistral = ["mistralai>=1.0,<2"] +providers-all = [ + "openai>=1.30,<2", + "anthropic>=0.30,<1", + "boto3>=1.34", + "google-cloud-aiplatform>=1.50,<2", + "ollama>=0.2", + "litellm>=1.40,<2", + "cohere>=5.0,<6", + "mistralai>=1.0,<2", +] + [project.urls] Homepage = "https://github.com/LayerLens/stratix-python" Repository = "https://github.com/LayerLens/stratix-python" diff --git a/samples/instrument/anthropic/__init__.py b/samples/instrument/anthropic/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samples/instrument/anthropic/main.py b/samples/instrument/anthropic/main.py new file mode 100644 index 0000000..c1000bc --- /dev/null +++ b/samples/instrument/anthropic/main.py @@ -0,0 +1,76 @@ +"""Sample: instrument the real Anthropic client with the LayerLens adapter. + +Required environment: + +* ``ANTHROPIC_API_KEY`` — your Anthropic API key. +* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional). 
+ +Run:: + + pip install 'layerlens[providers-anthropic]' + python -m samples.instrument.anthropic.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.providers.anthropic_adapter import AnthropicAdapter + + +def main() -> int: + if not os.environ.get("ANTHROPIC_API_KEY"): + print("ANTHROPIC_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + from anthropic import Anthropic + except ImportError: + print( + "anthropic package not installed. Install with:\n" + " pip install 'layerlens[providers-anthropic]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="anthropic", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = AnthropicAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + client = Anthropic() + adapter.connect_client(client) + + try: + response = client.messages.create( + model="claude-haiku-4-5-20251001", + max_tokens=20, + system="You are concise.", + messages=[{"role": "user", "content": "What is 2 + 2?"}], + ) + text_blocks = [b.text for b in response.content if getattr(b, "type", None) == "text"] + print(f"Response: {' '.join(text_blocks)}") + print( + f"Tokens — input: {response.usage.input_tokens}, " + f"output: {response.usage.output_tokens}" + ) + finally: + sink.close() + adapter.disconnect() + + print("Telemetry shipped. 
Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/cohere/__init__.py b/samples/instrument/cohere/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samples/instrument/cohere/main.py b/samples/instrument/cohere/main.py new file mode 100644 index 0000000..ff57777 --- /dev/null +++ b/samples/instrument/cohere/main.py @@ -0,0 +1,72 @@ +"""Sample: instrument the real Cohere client with the LayerLens adapter. + +Required env: ``COHERE_API_KEY``. Optional: ``LAYERLENS_STRATIX_API_KEY``, +``LAYERLENS_STRATIX_BASE_URL``. + +Run:: + + pip install 'layerlens[providers-cohere]' + python -m samples.instrument.cohere.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.providers.cohere_adapter import CohereAdapter + + +def main() -> int: + if not os.environ.get("COHERE_API_KEY"): + print("COHERE_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + import cohere + except ImportError: + print( + "cohere package not installed. 
Install with:\n" + " pip install 'layerlens[providers-cohere]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="cohere", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = CohereAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + client = cohere.Client() + adapter.connect_client(client) + + try: + response = client.chat( + model="command-r-plus", + message="What is 2 + 2?", + preamble="You are a concise assistant.", + ) + print(f"Response: {response.text}") + billed = getattr(response.meta, "billed_units", None) + if billed is not None: + input_tokens = getattr(billed, "input_tokens", 0) + output_tokens = getattr(billed, "output_tokens", 0) + print(f"Tokens — input: {input_tokens}, output: {output_tokens}") + finally: + sink.close() + adapter.disconnect() + + print("Telemetry shipped. Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/mistral/__init__.py b/samples/instrument/mistral/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samples/instrument/mistral/main.py b/samples/instrument/mistral/main.py new file mode 100644 index 0000000..45f15c8 --- /dev/null +++ b/samples/instrument/mistral/main.py @@ -0,0 +1,78 @@ +"""Sample: instrument the real Mistral client with the LayerLens adapter. + +Required env: ``MISTRAL_API_KEY``. Optional: ``LAYERLENS_STRATIX_API_KEY``, +``LAYERLENS_STRATIX_BASE_URL``. 
+ +Run:: + + pip install 'layerlens[providers-mistral]' + python -m samples.instrument.mistral.main +""" + +from __future__ import annotations + +import os +import sys + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.providers.mistral_adapter import MistralAdapter + + +def main() -> int: + if not os.environ.get("MISTRAL_API_KEY"): + print("MISTRAL_API_KEY is not set; cannot run sample.", file=sys.stderr) + return 2 + + try: + from mistralai import Mistral + except ImportError: + print( + "mistralai package not installed. Install with:\n" + " pip install 'layerlens[providers-mistral]'", + file=sys.stderr, + ) + return 2 + + sink = HttpEventSink( + adapter_name="mistral", + path="/telemetry/spans", + max_batch=10, + flush_interval_s=1.0, + ) + + adapter = MistralAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + + client = Mistral(api_key=os.environ["MISTRAL_API_KEY"]) + adapter.connect_client(client) + + try: + response = client.chat.complete( + model="mistral-small-latest", + messages=[ + {"role": "system", "content": "You are a concise assistant."}, + {"role": "user", "content": "What is 2 + 2?"}, + ], + max_tokens=20, + ) + text = response.choices[0].message.content if response.choices else "(empty)" + usage = response.usage + print(f"Response: {text}") + if usage is not None: + print( + f"Tokens — prompt: {usage.prompt_tokens}, " + f"completion: {usage.completion_tokens}, " + f"total: {usage.total_tokens}" + ) + finally: + sink.close() + adapter.disconnect() + + print("Telemetry shipped. 
Check the LayerLens dashboard adapter health page.") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/samples/instrument/openai/README.md b/samples/instrument/openai/README.md new file mode 100644 index 0000000..dadfae4 --- /dev/null +++ b/samples/instrument/openai/README.md @@ -0,0 +1,62 @@ +# OpenAI adapter sample + +> ⚠ **The platform telemetry endpoint (`/api/v1/telemetry/spans`) is not +> live yet.** It lands in atlas-app M1.B (integrations + telemetry-ingest +> Go packages). Until then, the sink will log +> `layerlens.sink.batch_dropped` after the third consecutive failure and +> events are not persisted. The adapter and SDK side are fully functional +> — you can run this sample today against any HTTP server that accepts +> JSON POSTs, including a local ngrok tunnel or the harness in +> `tests/instrument/test_sink_http_e2e.py`. + +This sample demonstrates the LayerLens OpenAI provider adapter wrapping a real +OpenAI client. Every chat completion or embedding call is intercepted and +turned into telemetry events shipped to atlas-app. + +## What you'll see + +Running `python -m samples.instrument.openai.main` produces three events for a +single chat completion: + +- `model.invoke` (L3) — the request and response, with parameters, tokens, and + latency. +- `cost.record` (cross-cutting) — the API cost in USD computed from the + pricing table for the requested model. +- `tool.call` (L5a, only if the model returned function calls) — one event per + tool call. + +The events are batched and POSTed to +`$LAYERLENS_STRATIX_BASE_URL/telemetry/spans` with `X-API-Key` auth. If no key +is present the sink runs anonymously and the platform may reject the events +depending on org policy. + +## Install + +```bash +pip install 'layerlens[providers-openai]' +``` + +The `providers-openai` extra installs `openai>=1.30,<2`. 
The default +`pip install layerlens` does NOT pull `openai` — that's the lazy-import +guarantee tested by `tests/instrument/test_lazy_imports.py`. + +## Run + +```bash +export OPENAI_API_KEY=sk-... +export LAYERLENS_STRATIX_API_KEY=ll-... # optional +python -m samples.instrument.openai.main +``` + +## Verify telemetry landed + +After the sample exits, check the LayerLens dashboard adapter health page — +the `openai` adapter row should show a recent `last_seen` timestamp and a +non-zero invocation count. + +## Streaming + +To run the sample against a streaming response, modify the `client.chat.completions.create` +call to add `stream=True, stream_options={"include_usage": True}` and iterate +the stream. The adapter's stream wrapper emits a single consolidated +`model.invoke` on stream completion, not one per chunk. diff --git a/samples/instrument/openai/__init__.py b/samples/instrument/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/samples/instrument/openai/main.py b/samples/instrument/openai/main.py new file mode 100644 index 0000000..5c60fb5 --- /dev/null +++ b/samples/instrument/openai/main.py @@ -0,0 +1,87 @@ +"""Sample: instrument the real OpenAI client with the LayerLens adapter. + +Runs a single chat completion through ``OpenAIAdapter`` with an +``HttpEventSink`` pointed at atlas-app. Every event the adapter emits +(``model.invoke``, ``cost.record``, optional ``tool.call``) is shipped +to the platform's telemetry ingest endpoint. + +Required environment: + +* ``OPENAI_API_KEY`` — your OpenAI API key. +* ``LAYERLENS_STRATIX_API_KEY`` — your LayerLens API key (optional; + defaults to anonymous if unset). +* ``LAYERLENS_STRATIX_BASE_URL`` — atlas-app base URL (optional; + defaults to ``https://api.layerlens.ai/api/v1``). 
def main() -> int:
    """Run one instrumented OpenAI chat completion and ship its telemetry.

    The OpenAI client is wrapped with ``OpenAIAdapter``; every event the
    adapter emits is handed to an ``HttpEventSink`` that POSTs batches to
    the ``/telemetry/spans`` path.

    Returns:
        ``0`` on success, ``2`` when ``OPENAI_API_KEY`` is unset or the
        ``openai`` package is not installed.
    """
    if not os.environ.get("OPENAI_API_KEY"):
        print("OPENAI_API_KEY is not set; cannot run sample.", file=sys.stderr)
        return 2

    # Imported lazily so the sample fails with a helpful message instead of
    # a traceback when the optional extra is not installed.
    try:
        from openai import OpenAI
    except ImportError:
        print(
            "openai package not installed. Install with:\n"
            "  pip install 'layerlens[providers-openai]'",
            file=sys.stderr,
        )
        return 2

    sink = HttpEventSink(
        adapter_name="openai",
        path="/telemetry/spans",
        max_batch=10,
        flush_interval_s=1.0,
    )
    adapter = OpenAIAdapter(capture_config=CaptureConfig.standard())

    # Everything after the sink exists runs under try/finally so a failure
    # while connecting or calling the API can never leave the sink open.
    try:
        adapter.add_sink(sink)
        adapter.connect()

        client = OpenAI()
        adapter.connect_client(client)

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a concise assistant."},
                {"role": "user", "content": "What is 2 + 2?"},
            ],
            max_tokens=20,
        )
        choice = response.choices[0].message.content if response.choices else "(empty)"
        usage = response.usage
        print(f"Response: {choice}")
        if usage is not None:
            print(
                f"Tokens — prompt: {usage.prompt_tokens}, completion: "
                f"{usage.completion_tokens}, total: {usage.total_tokens}"
            )
    finally:
        # Disconnect the adapter first (restores the patched client), then
        # close the sink — the same order the adapter docs use. The nested
        # finally guarantees the sink is closed even if disconnect() raises.
        try:
            adapter.disconnect()
        finally:
            sink.close()

    print("Telemetry shipped. Check the LayerLens dashboard adapter health page.")
    return 0
+""" + +from __future__ import annotations diff --git a/src/layerlens/instrument/adapters/providers/_base/__init__.py b/src/layerlens/instrument/adapters/providers/_base/__init__.py new file mode 100644 index 0000000..84aa7da --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_base/__init__.py @@ -0,0 +1,21 @@ +"""Shared base layer for LLM provider adapters.""" + +from __future__ import annotations + +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.pricing import ( + PRICING, + AZURE_PRICING, + BEDROCK_PRICING, + calculate_cost, +) +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter + +__all__ = [ + "AZURE_PRICING", + "BEDROCK_PRICING", + "LLMProviderAdapter", + "NormalizedTokenUsage", + "PRICING", + "calculate_cost", +] diff --git a/src/layerlens/instrument/adapters/providers/_base/pricing.py b/src/layerlens/instrument/adapters/providers/_base/pricing.py new file mode 100644 index 0000000..a193aee --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_base/pricing.py @@ -0,0 +1,147 @@ +"""LLM Model Pricing. + +Maintains pricing tables (per-1K-token rates) for all supported models +and provides cost calculation with cached-token adjustments. + +Ported verbatim from ``ateam/stratix/sdk/python/adapters/llm_providers/pricing.py``. +The pricing JSON is the canonical platform-wide source-of-truth and is +hash-checked between ``ateam`` and ``stratix-python`` in CI to prevent +drift. 
+""" + +from __future__ import annotations + +from typing import Dict, Optional + +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage + +# --------------------------------------------------------------------------- +# Pricing tables (per-1K-token rates, USD) +# --------------------------------------------------------------------------- + +PRICING: Dict[str, Dict[str, float]] = { + # OpenAI + "gpt-4o": {"input": 0.0025, "output": 0.0100}, + "gpt-4o-mini": {"input": 0.00015, "output": 0.0006}, + "gpt-4o-2024-11-20": {"input": 0.0025, "output": 0.0100}, + "gpt-4.1": {"input": 0.002, "output": 0.008}, + "gpt-4.1-mini": {"input": 0.0004, "output": 0.0016}, + "gpt-4.1-nano": {"input": 0.0001, "output": 0.0004}, + "gpt-4-turbo": {"input": 0.01, "output": 0.03}, + "gpt-4": {"input": 0.03, "output": 0.06}, + "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, + "o1": {"input": 0.015, "output": 0.060}, + "o1-mini": {"input": 0.003, "output": 0.012}, + "o3": {"input": 0.010, "output": 0.040}, + "o3-mini": {"input": 0.0011, "output": 0.0044}, + "o4-mini": {"input": 0.0011, "output": 0.0044}, + # Anthropic + "claude-sonnet-4-5-20250929": {"input": 0.003, "output": 0.015}, + "claude-opus-4-20250115": {"input": 0.015, "output": 0.075}, + "claude-opus-4-6": {"input": 0.015, "output": 0.075}, + "claude-haiku-4-5-20251001": {"input": 0.0008, "output": 0.004}, + "claude-haiku-3-5-20241022": {"input": 0.0008, "output": 0.004}, + "claude-3-5-sonnet-20241022": {"input": 0.003, "output": 0.015}, + "claude-3-opus-20240229": {"input": 0.015, "output": 0.075}, + "claude-3-haiku-20240307": {"input": 0.00025, "output": 0.00125}, + # Google + "gemini-2.5-pro": {"input": 0.00125, "output": 0.01}, + "gemini-2.5-flash": {"input": 0.000075, "output": 0.0003}, + "gemini-2.0-flash": {"input": 0.0001, "output": 0.0004}, + "gemini-1.5-pro": {"input": 0.00125, "output": 0.005}, + "gemini-1.5-flash": {"input": 0.000075, "output": 0.0003}, + # Meta (Ollama / Bedrock) 
+ "llama-3.3-70b": {"input": 0.00099, "output": 0.00099}, + "llama-3.1-70b": {"input": 0.00099, "output": 0.00099}, + "llama-3.1-8b": {"input": 0.00022, "output": 0.00022}, + # Mistral (direct API; Bedrock has its own table) + "mistral-large": {"input": 0.002, "output": 0.006}, + "mistral-large-latest": {"input": 0.002, "output": 0.006}, + "mistral-small": {"input": 0.0002, "output": 0.0006}, + "mistral-small-latest": {"input": 0.0002, "output": 0.0006}, + "mistral-medium": {"input": 0.0027, "output": 0.0081}, + "open-mistral-7b": {"input": 0.00025, "output": 0.00025}, + "open-mixtral-8x7b": {"input": 0.0007, "output": 0.0007}, + "open-mixtral-8x22b": {"input": 0.002, "output": 0.006}, + # Cohere (direct API; Bedrock-routed Cohere uses BEDROCK_PRICING) + "command-r-plus": {"input": 0.003, "output": 0.015}, + "command-r": {"input": 0.0005, "output": 0.0015}, + "command-r-plus-08-2024": {"input": 0.0025, "output": 0.01}, + "command-r-08-2024": {"input": 0.00015, "output": 0.0006}, + "command-light": {"input": 0.0003, "output": 0.0006}, + "command": {"input": 0.001, "output": 0.002}, +} + +AZURE_PRICING: Dict[str, Dict[str, float]] = { + "gpt-4o": {"input": 0.00275, "output": 0.011}, + "gpt-4o-mini": {"input": 0.000165, "output": 0.00066}, + "gpt-4-turbo": {"input": 0.011, "output": 0.033}, + "gpt-4": {"input": 0.033, "output": 0.066}, + "gpt-35-turbo": {"input": 0.00055, "output": 0.00165}, +} + +BEDROCK_PRICING: Dict[str, Dict[str, float]] = { + "anthropic.claude-3-5-sonnet-20241022-v2:0": {"input": 0.003, "output": 0.015}, + "anthropic.claude-3-opus-20240229-v1:0": {"input": 0.015, "output": 0.075}, + "anthropic.claude-3-haiku-20240307-v1:0": {"input": 0.00025, "output": 0.00125}, + "meta.llama3-1-70b-instruct-v1:0": {"input": 0.00099, "output": 0.00099}, + "meta.llama3-1-8b-instruct-v1:0": {"input": 0.00022, "output": 0.00022}, + "cohere.command-r-plus-v1:0": {"input": 0.003, "output": 0.015}, + "cohere.command-r-v1:0": {"input": 0.0005, "output": 0.0015}, +} + + 
def _cached_token_discount(model: str) -> float:
    """Return the cached-token rate as a fraction of the input price.

    Different providers offer different cache discounts:

    * Anthropic — 90% discount (pay 10% of input rate).
    * Google — 75% discount (pay 25% of input rate).
    * OpenAI and others — 50% discount (pay 50% of input rate).

    The family name is matched anywhere in the model id rather than only as
    a prefix, so provider-qualified ids such as the Bedrock keys
    ``anthropic.claude-...`` receive the same discount as the bare names.
    """
    lower = model.lower()
    if "claude" in lower:
        return 0.1
    if "gemini" in lower:
        return 0.25
    return 0.5


def calculate_cost(
    model: str,
    usage: NormalizedTokenUsage,
    pricing_table: Optional[Dict[str, Dict[str, float]]] = None,
) -> Optional[float]:
    """Calculate the API cost in USD for a model invocation.

    Args:
        model: Model name (e.g., ``"gpt-4o"``, ``"claude-sonnet-4-5-20250929"``).
        usage: Normalized token usage from the provider response.
        pricing_table: Override pricing table (for Azure / Bedrock).
            Defaults to :data:`PRICING` only when ``None``; an explicitly
            passed table is always honoured, even when it is empty.

    Returns:
        Cost in USD rounded to 8 decimal places, or ``None`` if the model
        is not in the pricing table.
    """
    # ``is None`` rather than ``or``: an empty override table must not
    # silently fall back to the public price list.
    table = PRICING if pricing_table is None else pricing_table
    rates = table.get(model)
    if rates is None:
        return None

    input_rate = rates.get("input", 0.0)
    output_rate = rates.get("output", 0.0)

    prompt_tokens = usage.prompt_tokens
    cached = usage.cached_tokens or 0

    # Cached tokens are treated as a subset of prompt_tokens; clamp at zero
    # so an inconsistent usage record can never produce a negative charge.
    non_cached = max(prompt_tokens - cached, 0)
    cached_rate = input_rate * _cached_token_discount(model)

    # Rates are per 1K tokens.
    cost = (
        (non_cached * input_rate / 1000)
        + (cached * cached_rate / 1000)
        + (usage.completion_tokens * output_rate / 1000)
    )

    return round(cost, 8)
Extends +:class:`BaseAdapter` with provider-specific emit helpers for +``model.invoke``, ``cost.record``, ``tool.call``, and +``policy.violation`` events. + +Supports W3C Trace Context propagation (``traceparent`` / +``tracestate``) for correlating spans across adapter boundaries. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/base_provider.py``. +""" + +from __future__ import annotations + +import time +import uuid +import logging +from abc import abstractmethod +from typing import Any, Dict, List, Optional + +from layerlens._compat.pydantic import model_dump +from layerlens.instrument.adapters._base.adapter import ( + AdapterInfo, + BaseAdapter, + AdapterHealth, + AdapterStatus, + ReplayableTrace, + AdapterCapability, +) +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.pricing import calculate_cost + +# W3C Trace Context header names. +_TRACEPARENT_HEADER = "traceparent" +_TRACESTATE_HEADER = "tracestate" + +logger = logging.getLogger(__name__) + + +class LLMProviderAdapter(BaseAdapter): + """Abstract base class for all LLM provider adapters. + + Provides concrete implementations for: + + * Event emission helpers (:meth:`_emit_model_invoke`, + :meth:`_emit_cost_record`, :meth:`_emit_tool_calls`, + :meth:`_emit_provider_error`). + * Lifecycle methods (:meth:`health_check`, + :meth:`get_adapter_info`, :meth:`serialize_for_replay`). + * Client reference management (``_client``, ``_originals``). + + Subclasses must implement: + + * :meth:`connect` — import framework, set HEALTHY. + * :meth:`disconnect` — restore originals, set DISCONNECTED. + * :meth:`connect_client` — wrap the provider client. 
+ """ + + adapter_type: str = "llm_provider" + + def __init__( + self, + stratix: Any = None, + capture_config: Optional[CaptureConfig] = None, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + self._client: Any = None + self._originals: Dict[str, Any] = {} + self._framework_version: Optional[str] = None + + # --- Abstract methods subclasses must implement --- + + @abstractmethod + def connect_client(self, client: Any) -> Any: + """Wrap or monkey-patch the provider client to intercept API calls. + + Args: + client: The provider SDK client instance. + + Returns: + The wrapped client (same object, modified in-place). + """ + + # --- Concrete lifecycle methods --- + + def connect(self) -> None: + """Verify framework availability and mark as connected.""" + self._framework_version = self._detect_framework_version() + self._connected = True + self._status = AdapterStatus.HEALTHY + + def disconnect(self) -> None: + """Restore all original methods and disconnect.""" + self._restore_originals() + self._client = None + self._originals.clear() + self._connected = False + self._status = AdapterStatus.DISCONNECTED + + def _restore_originals(self) -> None: + """Restore original methods on the client. 
Override for custom logic.""" + + def health_check(self) -> AdapterHealth: + return AdapterHealth( + status=self._status, + framework_name=self.FRAMEWORK, + framework_version=self._framework_version, + adapter_version=self.VERSION, + error_count=self._error_count, + circuit_open=self._circuit_open, + ) + + def get_adapter_info(self) -> AdapterInfo: + return AdapterInfo( + name=type(self).__name__, + version=self.VERSION, + framework=self.FRAMEWORK, + framework_version=self._framework_version, + capabilities=[ + AdapterCapability.TRACE_MODELS, + AdapterCapability.TRACE_TOOLS, + ], + description=f"LayerLens adapter for {self.FRAMEWORK} LLM provider", + ) + + def serialize_for_replay(self) -> ReplayableTrace: + return ReplayableTrace( + adapter_name=type(self).__name__, + framework=self.FRAMEWORK, + trace_id=str(uuid.uuid4()), + events=list(self._trace_events), + state_snapshots=[], + config={"capture_config": model_dump(self._capture_config)}, + ) + + @staticmethod + def _detect_framework_version() -> Optional[str]: + """Override in subclasses to detect SDK version.""" + return None + + # --- W3C Trace Context Propagation --- + + def _inject_trace_context( + self, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, str]: + """Inject W3C ``traceparent`` / ``tracestate`` headers for outbound requests. + + If OpenTelemetry is available, uses the OTel propagator. Otherwise + generates a minimal ``traceparent`` from the current trace / span + IDs. + + Args: + headers: Existing headers dict to inject into (mutated in place). + + Returns: + Headers dict with ``traceparent`` (and optionally ``tracestate``) added. 
+ """ + if headers is None: + headers = {} + + try: + from opentelemetry.propagate import inject + + inject(headers) + except ImportError: + trace_id = getattr(self, "_current_trace_id", None) + span_id = getattr(self, "_current_span_id", None) + if trace_id and span_id: + headers[_TRACEPARENT_HEADER] = f"00-{trace_id}-{span_id}-01" + + return headers + + def _extract_trace_context( + self, + headers: Dict[str, str], + ) -> Dict[str, str]: + """Extract W3C ``traceparent`` / ``tracestate`` from inbound headers. + + Args: + headers: Inbound headers dict. + + Returns: + Dict with ``trace_id``, ``parent_span_id``, ``trace_flags``, + and optionally ``tracestate``. + """ + result: Dict[str, str] = {} + + traceparent = headers.get(_TRACEPARENT_HEADER, "") + if traceparent: + parts = traceparent.split("-") + if len(parts) >= 4: + result["trace_id"] = parts[1] + result["parent_span_id"] = parts[2] + result["trace_flags"] = parts[3] + + tracestate = headers.get(_TRACESTATE_HEADER, "") + if tracestate: + result["tracestate"] = tracestate + + return result + + # --- Event emission helpers --- + + def _emit_model_invoke( + self, + provider: str, + model: Optional[str], + parameters: Optional[Dict[str, Any]] = None, + usage: Optional[NormalizedTokenUsage] = None, + latency_ms: Optional[float] = None, + error: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + input_messages: Optional[List[Dict[str, str]]] = None, + output_message: Optional[Dict[str, str]] = None, + ) -> None: + """Emit a ``model.invoke`` (L3) event.""" + payload: Dict[str, Any] = { + "provider": provider, + "model": model, + "timestamp_ns": time.time_ns(), + } + if parameters: + payload["parameters"] = parameters + if usage: + payload["prompt_tokens"] = usage.prompt_tokens + payload["completion_tokens"] = usage.completion_tokens + payload["total_tokens"] = usage.total_tokens + if usage.cached_tokens is not None: + payload["cached_tokens"] = usage.cached_tokens + if usage.reasoning_tokens is not 
None: + payload["reasoning_tokens"] = usage.reasoning_tokens + if latency_ms is not None: + payload["latency_ms"] = latency_ms + if error: + payload["error"] = error + if metadata: + for k, v in metadata.items(): + if k not in payload: + payload[k] = v + if self._capture_config.capture_content: + if input_messages: + payload["messages"] = input_messages + if output_message: + payload["output_message"] = output_message + + self.emit_dict_event("model.invoke", payload) + + @staticmethod + def _normalize_messages( + raw_messages: Any, + system: Any = None, + ) -> Optional[List[Dict[str, str]]]: + """Normalize provider-specific message formats to ``[{role, content}]``. + + Args: + raw_messages: Messages from the provider SDK kwargs (list of + dicts, list of objects, or ``None``). + system: Separate system prompt (e.g. Anthropic's ``system`` + kwarg). May be a string or a list of content blocks. + + Returns: + Normalized list, or ``None`` if no messages were found. + """ + if not raw_messages and not system: + return None + + messages: List[Dict[str, str]] = [] + + if system: + if isinstance(system, str): + messages.append({"role": "system", "content": system[:10_000]}) + elif isinstance(system, list): + parts: List[str] = [] + for block in system: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict) and "text" in block: + parts.append(str(block["text"])) + if parts: + messages.append({"role": "system", "content": "\n".join(parts)[:10_000]}) + + if raw_messages: + for msg in raw_messages: + role = "" + content = "" + if isinstance(msg, dict): + role = str(msg.get("role", "")) + raw_content = msg.get("content", "") + if isinstance(raw_content, str): + content = raw_content + elif isinstance(raw_content, list): + parts2: List[str] = [] + for part in raw_content: + if isinstance(part, str): + parts2.append(part) + elif isinstance(part, dict): + text = part.get("text") or part.get("content", "") + if text: + parts2.append(str(text)) + content = 
"\n".join(parts2) + else: + content = str(raw_content) if raw_content else "" + elif hasattr(msg, "role") and hasattr(msg, "content"): + role = str(getattr(msg, "role", "")) + raw_content = getattr(msg, "content", "") + if isinstance(raw_content, str): + content = raw_content + elif isinstance(raw_content, list): + parts3: List[str] = [] + for part in raw_content: + if isinstance(part, str): + parts3.append(part) + elif hasattr(part, "text"): + parts3.append(str(part.text)) + elif isinstance(part, dict) and "text" in part: + parts3.append(str(part["text"])) + content = "\n".join(parts3) + else: + content = str(raw_content) if raw_content else "" + else: + continue + + if role: + messages.append({"role": role, "content": content[:10_000]}) + + return messages if messages else None + + def _emit_cost_record( + self, + model: Optional[str], + usage: Optional[NormalizedTokenUsage], + provider: Optional[str] = None, + pricing_table: Optional[Dict[str, Dict[str, float]]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit a ``cost.record`` (cross-cutting) event.""" + payload: Dict[str, Any] = { + "provider": provider or self.FRAMEWORK, + "model": model, + } + + if usage: + payload["prompt_tokens"] = usage.prompt_tokens + payload["completion_tokens"] = usage.completion_tokens + payload["total_tokens"] = usage.total_tokens + + cost = calculate_cost(model or "", usage, pricing_table) + if cost is not None: + payload["api_cost_usd"] = cost + else: + payload["api_cost_usd"] = None + payload["pricing_unavailable"] = True + + if metadata: + for k, v in metadata.items(): + if k not in payload: + payload[k] = v + + self.emit_dict_event("cost.record", payload) + + def _emit_tool_calls( + self, + tool_calls: List[Dict[str, Any]], + parent_model: Optional[str] = None, + ) -> None: + """Emit ``tool.call`` (L5a) events for function / tool calls in a response.""" + for tc in tool_calls: + payload: Dict[str, Any] = { + "tool_name": tc.get("name", "unknown"), + 
"tool_input": tc.get("arguments") or tc.get("input"), + "provider": self.FRAMEWORK, + } + if parent_model: + payload["model"] = parent_model + if "id" in tc: + payload["tool_call_id"] = tc["id"] + + self.emit_dict_event("tool.call", payload) + + def _emit_provider_error( + self, + provider: str, + error: str, + model: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Emit ``policy.violation`` (cross-cutting) for provider errors.""" + payload: Dict[str, Any] = { + "provider": provider, + "error": error, + "violation_type": "safety", + } + if model: + payload["model"] = model + if metadata: + for k, v in metadata.items(): + if k not in payload: + payload[k] = v + + self.emit_dict_event("policy.violation", payload) diff --git a/src/layerlens/instrument/adapters/providers/_base/tokens.py b/src/layerlens/instrument/adapters/providers/_base/tokens.py new file mode 100644 index 0000000..69c7c7c --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_base/tokens.py @@ -0,0 +1,80 @@ +"""Normalized Token Usage. + +Provides a common data structure for token usage across all LLM +providers. Each provider adapter constructs this from its own response +format. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/token_usage.py``. + +The source uses Pydantic v2's ``model_validator`` and ``model_copy``, +which do not exist in Pydantic v1. The ``stratix-python`` SDK pins +``pydantic>=1.9.0, <3``, so this port avoids both v2-only features: + +* The auto-total behavior is implemented as :meth:`with_auto_total` + classmethod and :meth:`compute_total` instance method that construct + fresh instances rather than relying on a validator hook. +* Callers in this codebase always pass an explicit ``total_tokens``, + so the auto-compute is purely a defensive convenience for external + callers. 
class NormalizedTokenUsage(BaseModel):
    """Normalized token usage across all LLM providers."""

    prompt_tokens: int = Field(default=0, description="Input tokens (prompt, system, context)")
    completion_tokens: int = Field(default=0, description="Output tokens (response, generation)")
    total_tokens: int = Field(default=0, description="prompt_tokens + completion_tokens")
    cached_tokens: Optional[int] = Field(
        default=None,
        description="Cached prompt tokens (OpenAI cached, Anthropic cache_read)",
    )
    reasoning_tokens: Optional[int] = Field(
        default=None,
        description="Reasoning tokens (o1/o3 reasoning, Claude extended thinking)",
    )

    @classmethod
    def with_auto_total(
        cls,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        total_tokens: int = 0,
        cached_tokens: Optional[int] = None,
        reasoning_tokens: Optional[int] = None,
    ) -> "NormalizedTokenUsage":
        """Build a usage record, filling in ``total_tokens`` when it is zero.

        Intended for provider responses that do not report a total. A
        non-zero ``total_tokens`` is passed through unchanged; a zero total
        is replaced by ``prompt_tokens + completion_tokens`` whenever either
        of those is non-zero.
        """
        total_missing = total_tokens == 0 and bool(prompt_tokens or completion_tokens)
        resolved_total = prompt_tokens + completion_tokens if total_missing else total_tokens
        return cls(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=resolved_total,
            cached_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )

    def compute_total(self) -> "NormalizedTokenUsage":
        """Return a new instance whose ``total_tokens`` is prompt + completion.

        A fresh instance is constructed instead of using Pydantic v2's
        ``model_copy(update=...)`` so the code runs under both v1 and v2.
        """
        fields = {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.prompt_tokens + self.completion_tokens,
            "cached_tokens": self.cached_tokens,
            "reasoning_tokens": self.reasoning_tokens,
        }
        return type(self)(**fields)
+ + Usage:: + + from anthropic import Anthropic + from layerlens.instrument.adapters.providers.anthropic_adapter import AnthropicAdapter + + adapter = AnthropicAdapter() + adapter.connect() + + client = Anthropic() + adapter.connect_client(client) + """ + + FRAMEWORK = "anthropic" + VERSION = "0.1.0" + + def __init__( + self, + stratix: Any = None, + capture_config: Optional[CaptureConfig] = None, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + + def connect_client(self, client: Any) -> Any: + """Wrap Anthropic client methods with tracing.""" + self._client = client + + if hasattr(client, "messages"): + original_create = client.messages.create + self._originals["messages.create"] = original_create + client.messages.create = self._wrap_messages_create(original_create) + + if hasattr(client.messages, "stream"): + original_stream = client.messages.stream + self._originals["messages.stream"] = original_stream + client.messages.stream = self._wrap_messages_stream(original_stream) + + return client + + def _restore_originals(self) -> None: + if self._client is None: + return + if "messages.create" in self._originals: + try: + self._client.messages.create = self._originals["messages.create"] + except Exception: + logger.warning("Could not restore messages.create") + if "messages.stream" in self._originals: + try: + self._client.messages.stream = self._originals["messages.stream"] + except Exception: + logger.warning("Could not restore messages.stream") + + @staticmethod + def _detect_framework_version() -> Optional[str]: + try: + import anthropic + + version = getattr(anthropic, "__version__", None) + return str(version) if version is not None else None + except ImportError: + return None + + # --- Wrapping methods --- + + def _wrap_messages_create(self, original: Any) -> Any: + adapter = self + + def traced_create(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + params = {k: kwargs[k] for k in _CAPTURE_PARAMS if k in 
kwargs} + if "system" in kwargs: + params["has_system"] = True + tools = kwargs.get("tools") + if tools: + params["tools_count"] = len(tools) + is_stream = kwargs.get("stream", False) + start_ns = time.time_ns() + + input_messages = adapter._normalize_messages( + kwargs.get("messages"), + system=kwargs.get("system"), + ) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="anthropic", + model=model, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + input_messages=input_messages, + ) + adapter._emit_provider_error("anthropic", str(exc), model=model) + except Exception: + logger.warning("Error emitting Anthropic error event", exc_info=True) + raise + + if is_stream: + return adapter._wrap_stream_response( + response, model, params, start_ns, input_messages + ) + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + output_message = adapter._extract_output_message(response) + + metadata: Dict[str, Any] = {} + stop_reason = getattr(response, "stop_reason", None) + if stop_reason is not None: + metadata["finish_reason"] = stop_reason + resp_id = getattr(response, "id", None) + if resp_id is not None: + metadata["response_id"] = resp_id + resp_model = getattr(response, "model", None) + if resp_model is not None: + metadata["response_model"] = resp_model + resp_usage = getattr(response, "usage", None) + if resp_usage is not None: + cache_create = getattr(resp_usage, "cache_creation_input_tokens", None) + if cache_create is not None: + metadata["cache_creation_input_tokens"] = cache_create + cache_read = getattr(resp_usage, "cache_read_input_tokens", None) + if cache_read is not None: + metadata["cache_read_input_tokens"] = cache_read + + adapter._emit_model_invoke( + provider="anthropic", + model=model, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + 
input_messages=input_messages, + output_message=output_message, + metadata=metadata if metadata else None, + ) + adapter._emit_cost_record(model=model, usage=usage, provider="anthropic") + + tool_calls = adapter._extract_tool_use(response) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model) + except Exception: + logger.warning("Error emitting Anthropic trace events", exc_info=True) + + return response + + traced_create._layerlens_original = original # type: ignore[attr-defined] + return traced_create + + def _wrap_messages_stream(self, original: Any) -> Any: + """Wrap the ``messages.stream`` context manager.""" + adapter = self + + def traced_stream(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + params = {k: kwargs[k] for k in _CAPTURE_PARAMS if k in kwargs} + if "system" in kwargs: + params["has_system"] = True + tools = kwargs.get("tools") + if tools: + params["tools_count"] = len(tools) + start_ns = time.time_ns() + + input_messages = adapter._normalize_messages( + kwargs.get("messages"), + system=kwargs.get("system"), + ) + + try: + stream_ctx = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="anthropic", + model=model, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + input_messages=input_messages, + ) + adapter._emit_provider_error("anthropic", str(exc), model=model) + except Exception: + logger.warning("Error emitting Anthropic stream error", exc_info=True) + raise + + return _TracedStreamManager( + adapter, stream_ctx, model, params, start_ns, input_messages + ) + + traced_stream._layerlens_original = original # type: ignore[attr-defined] + return traced_stream + + def _wrap_stream_response( + self, + stream: Any, + model: Optional[str], + params: Dict[str, Any], + start_ns: int, + input_messages: Optional[List[Dict[str, str]]] = None, + ) -> Any: + """Wrap a streaming response (from 
``stream=True``) iterator.""" + adapter = self + accumulated_tool_calls: List[Dict[str, Any]] = [] + accumulated_content: List[str] = [] + final_usage: Optional[NormalizedTokenUsage] = None + stream_finish_reason: Optional[str] = None + stream_response_id: Optional[str] = None + stream_response_model: Optional[str] = None + + class TracedStream: + def __init__(self, inner: Any) -> None: + self._inner = inner + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + try: + event = next(self._inner) + except StopIteration: + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + output_msg: Optional[Dict[str, str]] = None + if accumulated_content: + output_msg = { + "role": "assistant", + "content": "".join(accumulated_content)[:10_000], + } + stream_meta: Dict[str, Any] = {"streaming": True} + if stream_finish_reason is not None: + stream_meta["finish_reason"] = stream_finish_reason + if stream_response_id is not None: + stream_meta["response_id"] = stream_response_id + if stream_response_model is not None: + stream_meta["response_model"] = stream_response_model + adapter._emit_model_invoke( + provider="anthropic", + model=model, + parameters=params, + usage=final_usage, + latency_ms=elapsed_ms, + metadata=stream_meta, + input_messages=input_messages, + output_message=output_msg, + ) + if final_usage: + adapter._emit_cost_record( + model=model, + usage=final_usage, + provider="anthropic", + ) + if accumulated_tool_calls: + adapter._emit_tool_calls( + accumulated_tool_calls, parent_model=model + ) + except Exception: + logger.warning( + "Error emitting Anthropic stream events", exc_info=True + ) + raise + + try: + _process_stream_event(event) + except Exception: + logger.debug("Error processing Anthropic stream event", exc_info=True) + return event + + def __enter__(self) -> Any: + return self + + def __exit__(self, *args: Any) -> Any: + if hasattr(self._inner, "__exit__"): + return self._inner.__exit__(*args) + return None + + 
def close(self) -> None: + if hasattr(self._inner, "close"): + self._inner.close() + + def _process_stream_event(event: Any) -> None: + nonlocal final_usage, stream_finish_reason, stream_response_id, stream_response_model + event_type = getattr(event, "type", None) + if event_type == "content_block_delta": + delta = getattr(event, "delta", None) + if delta and getattr(delta, "type", None) == "text_delta": + text = getattr(delta, "text", "") + if text: + accumulated_content.append(text) + if event_type == "message_delta": + stop_reason = getattr(event, "delta", None) + if stop_reason is not None: + sr = getattr(stop_reason, "stop_reason", None) + if sr is not None: + stream_finish_reason = sr + usage_data = getattr(event, "usage", None) + if usage_data: + output = getattr(usage_data, "output_tokens", 0) or 0 + prior_prompt = final_usage.prompt_tokens if final_usage else 0 + final_usage = NormalizedTokenUsage( + prompt_tokens=prior_prompt, + completion_tokens=output, + total_tokens=prior_prompt + output, + ) + elif event_type == "message_start": + msg = getattr(event, "message", None) + if msg: + msg_id = getattr(msg, "id", None) + if msg_id is not None: + stream_response_id = msg_id + msg_model = getattr(msg, "model", None) + if msg_model is not None: + stream_response_model = msg_model + usage_data = getattr(msg, "usage", None) + if usage_data: + final_usage = adapter._extract_usage_from_obj(usage_data) + elif event_type == "content_block_start": + block = getattr(event, "content_block", None) + if block and getattr(block, "type", None) == "tool_use": + accumulated_tool_calls.append( + { + "name": getattr(block, "name", "unknown"), + "input": {}, + "id": getattr(block, "id", None), + "_json_parts": [], + } + ) + elif event_type == "content_block_delta": + delta = getattr(event, "delta", None) + if delta and getattr(delta, "type", None) == "input_json_delta": + json_str = getattr(delta, "partial_json", "") + if accumulated_tool_calls and json_str: + 
accumulated_tool_calls[-1]["_json_parts"].append(json_str) + elif event_type == "content_block_stop": + if accumulated_tool_calls and accumulated_tool_calls[-1].get("_json_parts"): + import json as _json + + try: + full_json = "".join(accumulated_tool_calls[-1].pop("_json_parts")) + accumulated_tool_calls[-1]["input"] = _json.loads(full_json) + except Exception: + accumulated_tool_calls[-1].pop("_json_parts", None) + + return TracedStream(stream) + + # --- Token extraction --- + + def _extract_usage(self, response: Any) -> Optional[NormalizedTokenUsage]: + usage = getattr(response, "usage", None) + if not usage: + return None + return self._extract_usage_from_obj(usage) + + @staticmethod + def _extract_usage_from_obj(usage: Any) -> NormalizedTokenUsage: + input_tokens = getattr(usage, "input_tokens", 0) or 0 + output_tokens = getattr(usage, "output_tokens", 0) or 0 + + cached = getattr(usage, "cache_read_input_tokens", None) + reasoning = getattr(usage, "thinking_tokens", None) + + return NormalizedTokenUsage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + cached_tokens=cached, + reasoning_tokens=reasoning, + ) + + @staticmethod + def _extract_output_message(response: Any) -> Optional[Dict[str, str]]: + """Extract the assistant output message from an Anthropic response.""" + try: + content = getattr(response, "content", None) or [] + parts: List[str] = [] + for block in content: + if getattr(block, "type", None) == "text": + parts.append(getattr(block, "text", "")) + if parts: + return {"role": "assistant", "content": "\n".join(parts)[:10_000]} + except Exception: + logger.debug("Error extracting Anthropic output message", exc_info=True) + return None + + @staticmethod + def _extract_tool_use(response: Any) -> List[Dict[str, Any]]: + """Extract ``tool_use`` blocks from an Anthropic response.""" + tool_calls: List[Dict[str, Any]] = [] + try: + content = getattr(response, "content", None) or [] + for 
block in content: + if getattr(block, "type", None) == "tool_use": + tool_calls.append( + { + "name": getattr(block, "name", "unknown"), + "input": getattr(block, "input", {}), + "id": getattr(block, "id", None), + } + ) + except Exception: + logger.debug("Error extracting Anthropic tool_use blocks", exc_info=True) + return tool_calls + + +class _TracedStreamManager: + """Wraps the Anthropic ``messages.stream()`` context manager.""" + + def __init__( + self, + adapter: AnthropicAdapter, + inner: Any, + model: Optional[str], + params: Dict[str, Any], + start_ns: int, + input_messages: Optional[List[Dict[str, str]]] = None, + ) -> None: + self._adapter = adapter + self._inner = inner + self._model = model + self._params = params + self._start_ns = start_ns + self._input_messages = input_messages + + def __enter__(self) -> Any: + stream = self._inner.__enter__() + return self._adapter._wrap_stream_response( + stream, + self._model, + self._params, + self._start_ns, + self._input_messages, + ) + + def __exit__(self, *args: Any) -> Any: + return self._inner.__exit__(*args) + + +# Registry lazy-loading convention. +ADAPTER_CLASS = AnthropicAdapter diff --git a/src/layerlens/instrument/adapters/providers/azure_openai_adapter.py b/src/layerlens/instrument/adapters/providers/azure_openai_adapter.py new file mode 100644 index 0000000..c0bfc96 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/azure_openai_adapter.py @@ -0,0 +1,252 @@ +"""Azure OpenAI LLM Provider Adapter. + +Same wrapping as OpenAI (same SDK) with additional capture of +``deployment_name``, ``azure_endpoint``, ``api_version``, and region. +Uses the Azure-specific pricing table. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/azure_openai_adapter.py``. 
+""" + +from __future__ import annotations + +import time +import logging +from typing import Any, Dict, Optional +from urllib.parse import urlparse, urlunparse + +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.pricing import AZURE_PRICING +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter +from layerlens.instrument.adapters.providers.openai_adapter import OpenAIAdapter + +logger = logging.getLogger(__name__) + +_CAPTURE_PARAMS = frozenset( + { + "model", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "response_format", + "tool_choice", + } +) + + +class AzureOpenAIAdapter(LLMProviderAdapter): + """LayerLens adapter for Azure OpenAI Service. + + Uses the same ``openai`` SDK but captures Azure-specific metadata + (deployment, endpoint, region) and uses Azure pricing. + """ + + FRAMEWORK = "azure_openai" + VERSION = "0.1.0" + + def __init__( + self, + stratix: Any = None, + capture_config: Optional[CaptureConfig] = None, + ) -> None: + super().__init__(stratix=stratix, capture_config=capture_config) + self._azure_metadata: Dict[str, Any] = {} + + @staticmethod + def _sanitize_endpoint(url: Any) -> Optional[str]: + """Strip query parameters from the endpoint URL to prevent token leakage.""" + if url is None: + return None + url_str = str(url) + parsed = urlparse(url_str) + return urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", "")) + + def connect_client(self, client: Any) -> Any: + """Wrap Azure OpenAI client methods with tracing.""" + self._client = client + + raw_endpoint = getattr(client, "_base_url", None) or getattr(client, "base_url", None) + self._azure_metadata = { + "azure_endpoint": self._sanitize_endpoint(raw_endpoint), + "api_version": getattr(client, "_api_version", None), + } + custom_query = getattr(client, "_custom_query", None) + if custom_query and isinstance(custom_query, dict): + 
self._azure_metadata["api_version"] = custom_query.get( + "api-version", self._azure_metadata.get("api_version") + ) + + if hasattr(client, "chat") and hasattr(client.chat, "completions"): + original_create = client.chat.completions.create + self._originals["chat.completions.create"] = original_create + client.chat.completions.create = self._wrap_chat_create(original_create) + + if hasattr(client, "embeddings"): + original_embed = client.embeddings.create + self._originals["embeddings.create"] = original_embed + client.embeddings.create = self._wrap_embeddings_create(original_embed) + + return client + + def _restore_originals(self) -> None: + if self._client is None: + return + if "chat.completions.create" in self._originals: + try: + self._client.chat.completions.create = self._originals[ + "chat.completions.create" + ] + except Exception: + logger.warning("Could not restore chat.completions.create") + if "embeddings.create" in self._originals: + try: + self._client.embeddings.create = self._originals["embeddings.create"] + except Exception: + logger.warning("Could not restore embeddings.create") + + @staticmethod + def _detect_framework_version() -> Optional[str]: + try: + import openai + + version = getattr(openai, "__version__", None) + return str(version) if version is not None else None + except ImportError: + return None + + def _wrap_chat_create(self, original: Any) -> Any: + adapter = self + + def traced_create(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + params = {k: kwargs[k] for k in _CAPTURE_PARAMS if k in kwargs} + start_ns = time.time_ns() + + input_messages = adapter._normalize_messages(kwargs.get("messages")) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="azure_openai", + model=model, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + metadata=adapter._azure_metadata, + 
input_messages=input_messages, + ) + adapter._emit_provider_error("azure_openai", str(exc), model=model) + except Exception: + logger.warning("Error emitting Azure error event", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + resp_usage = getattr(response, "usage", None) + usage = ( + OpenAIAdapter._extract_usage_from_obj(resp_usage) if resp_usage else None + ) + output_message = OpenAIAdapter._extract_output_message(response) + + merged_metadata: Dict[str, Any] = dict(adapter._azure_metadata) + choices = getattr(response, "choices", None) or [] + if choices: + fr = getattr(choices[0], "finish_reason", None) + if fr is not None: + merged_metadata["finish_reason"] = fr + resp_id = getattr(response, "id", None) + if resp_id is not None: + merged_metadata["response_id"] = resp_id + resp_model = getattr(response, "model", None) + if resp_model is not None: + merged_metadata["response_model"] = resp_model + sys_fp = getattr(response, "system_fingerprint", None) + if sys_fp is not None: + merged_metadata["system_fingerprint"] = sys_fp + + adapter._emit_model_invoke( + provider="azure_openai", + model=model, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + metadata=merged_metadata, + input_messages=input_messages, + output_message=output_message, + ) + adapter._emit_cost_record( + model=model, + usage=usage, + provider="azure_openai", + pricing_table=AZURE_PRICING, + ) + + tool_calls = OpenAIAdapter._extract_tool_calls(response) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model) + except Exception: + logger.warning("Error emitting Azure trace events", exc_info=True) + + return response + + traced_create._layerlens_original = original # type: ignore[attr-defined] + return traced_create + + def _wrap_embeddings_create(self, original: Any) -> Any: + adapter = self + + def traced_embed(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + start_ns = time.time_ns() + + try: + response 
= original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="azure_openai", + model=model, + latency_ms=elapsed_ms, + error=str(exc), + metadata={**adapter._azure_metadata, "request_type": "embedding"}, + ) + except Exception: + logger.warning("Error emitting Azure embedding error", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + resp_usage = getattr(response, "usage", None) + usage = ( + OpenAIAdapter._extract_usage_from_obj(resp_usage) if resp_usage else None + ) + + adapter._emit_model_invoke( + provider="azure_openai", + model=model, + usage=usage, + latency_ms=elapsed_ms, + metadata={**adapter._azure_metadata, "request_type": "embedding"}, + ) + adapter._emit_cost_record( + model=model, + usage=usage, + provider="azure_openai", + pricing_table=AZURE_PRICING, + ) + except Exception: + logger.warning("Error emitting Azure embedding events", exc_info=True) + + return response + + traced_embed._layerlens_original = original # type: ignore[attr-defined] + return traced_embed + + +# Registry lazy-loading convention. +ADAPTER_CLASS = AzureOpenAIAdapter diff --git a/src/layerlens/instrument/adapters/providers/bedrock_adapter.py b/src/layerlens/instrument/adapters/providers/bedrock_adapter.py new file mode 100644 index 0000000..cdd85b9 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/bedrock_adapter.py @@ -0,0 +1,592 @@ +"""AWS Bedrock LLM Provider Adapter. + +Wraps ``invoke_model``, ``invoke_model_with_response_stream``, +``converse``, and ``converse_stream``. Parses ``modelId`` to detect the +provider family for token extraction. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/bedrock_adapter.py``. 
+""" + +from __future__ import annotations + +import json +import time +import logging +from typing import Any, Dict, List, Iterator, Optional + +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.pricing import BEDROCK_PRICING +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter + +logger = logging.getLogger(__name__) + + +def _detect_provider_family(model_id: str) -> str: + """Detect the provider family from a Bedrock ``modelId``.""" + if not model_id: + return "unknown" + lower = model_id.lower() + if lower.startswith("anthropic."): + return "anthropic" + if lower.startswith("meta."): + return "meta" + if lower.startswith("cohere."): + return "cohere" + if lower.startswith("amazon."): + return "amazon" + if lower.startswith("ai21."): + return "ai21" + if lower.startswith("mistral."): + return "mistral" + return "unknown" + + +class AWSBedrockAdapter(LLMProviderAdapter): + """LayerLens adapter for AWS Bedrock (``bedrock-runtime``). + + Wraps ``invoke_model``, ``invoke_model_with_response_stream``, + ``converse``, and ``converse_stream``. Parses ``modelId`` for + provider-specific token extraction. 
def connect_client(self, client: Any) -> Any:
    """Wrap a Bedrock runtime client with tracing.

    For each of ``invoke_model``, ``converse``,
    ``invoke_model_with_response_stream`` and ``converse_stream`` that the
    client exposes, the current method is saved in ``self._originals``
    under its name and replaced on the client by the traced wrapper.

    Args:
        client: The client object to instrument in place.

    Returns:
        The same ``client`` object.
    """
    self._client = client

    # (client method name, name of the wrapper factory on this adapter)
    wrap_table = (
        ("invoke_model", "_wrap_invoke_model"),
        ("converse", "_wrap_converse"),
        ("invoke_model_with_response_stream", "_wrap_invoke_stream"),
        ("converse_stream", "_wrap_converse_stream"),
    )
    for method_name, factory_name in wrap_table:
        if not hasattr(client, method_name):
            continue
        untraced = getattr(client, method_name)
        self._originals[method_name] = untraced
        setattr(client, method_name, getattr(self, factory_name)(untraced))

    return client
original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"method": "invoke_model"}, + input_messages=input_messages, + ) + adapter._emit_provider_error("aws_bedrock", str(exc), model=model_id) + except Exception: + logger.warning("Error emitting Bedrock error event", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + body = response.get("body") + body_data: Dict[str, Any] = {} + if body and hasattr(body, "read"): + body_bytes = body.read() + try: + body_data = json.loads(body_bytes) + except (json.JSONDecodeError, TypeError, ValueError): + body_data = {} + response["body"] = _RereadableBody(body_bytes) + + family = _detect_provider_family(model_id) + usage = adapter._extract_invoke_usage(body_data, family) + output_message = adapter._extract_invoke_output(body_data, family) + + invoke_metadata: Dict[str, Any] = { + "method": "invoke_model", + "provider_family": family, + } + if family == "anthropic": + sr = body_data.get("stop_reason") + if sr is not None: + invoke_metadata["finish_reason"] = sr + rid = body_data.get("id") + if rid is not None: + invoke_metadata["response_id"] = rid + elif family in ("meta", "mistral"): + sr = body_data.get("stop_reason") + if sr is not None: + invoke_metadata["finish_reason"] = sr + elif family == "cohere": + gens = body_data.get("generations", []) + if gens and isinstance(gens, list): + sr = gens[0].get("finish_reason") + if sr is not None: + invoke_metadata["finish_reason"] = sr + else: + sr = body_data.get("stop_reason") or body_data.get("finish_reason") + if sr is not None: + invoke_metadata["finish_reason"] = sr + + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + usage=usage, + latency_ms=elapsed_ms, + metadata=invoke_metadata, + input_messages=input_messages, + 
output_message=output_message, + ) + adapter._emit_cost_record( + model=model_id, + usage=usage, + provider="aws_bedrock", + pricing_table=BEDROCK_PRICING, + ) + except Exception: + logger.warning("Error emitting Bedrock invoke events", exc_info=True) + + return response + + traced_invoke._layerlens_original = original # type: ignore[attr-defined] + return traced_invoke + + def _wrap_converse(self, original: Any) -> Any: + adapter = self + + def traced_converse(*args: Any, **kwargs: Any) -> Any: + model_id = kwargs.get("modelId", "") + start_ns = time.time_ns() + + input_messages = adapter._normalize_messages(kwargs.get("messages")) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"method": "converse"}, + input_messages=input_messages, + ) + adapter._emit_provider_error("aws_bedrock", str(exc), model=model_id) + except Exception: + logger.warning("Error emitting Bedrock converse error", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_converse_usage(response) + output_message = adapter._extract_converse_output(response) + + converse_metadata: Dict[str, Any] = {"method": "converse"} + stop_reason = response.get("stopReason") + if stop_reason is not None: + converse_metadata["finish_reason"] = stop_reason + resp_meta = response.get("ResponseMetadata", {}) + request_id = resp_meta.get("RequestId") + if request_id is not None: + converse_metadata["response_id"] = request_id + + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + usage=usage, + latency_ms=elapsed_ms, + metadata=converse_metadata, + input_messages=input_messages, + output_message=output_message, + ) + adapter._emit_cost_record( + model=model_id, + usage=usage, + provider="aws_bedrock", + 
pricing_table=BEDROCK_PRICING, + ) + except Exception: + logger.warning("Error emitting Bedrock converse events", exc_info=True) + + return response + + traced_converse._layerlens_original = original # type: ignore[attr-defined] + return traced_converse + + def _wrap_invoke_stream(self, original: Any) -> Any: + """Wrap ``invoke_model_with_response_stream``. + + ``output_message`` is intentionally not extracted here because + the response is a stream — content is not available until the + caller fully consumes the iterator. + """ + adapter = self + + def traced_invoke_stream(*args: Any, **kwargs: Any) -> Any: + model_id = kwargs.get("modelId", "") + start_ns = time.time_ns() + + input_messages = adapter._extract_invoke_messages(kwargs, model_id) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"method": "invoke_model_with_response_stream"}, + input_messages=input_messages, + ) + except Exception: + logger.warning("Error emitting Bedrock stream error", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + metadata={ + "method": "invoke_model_with_response_stream", + "streaming": True, + }, + input_messages=input_messages, + ) + except Exception: + logger.warning("Error emitting Bedrock stream events", exc_info=True) + + return response + + traced_invoke_stream._layerlens_original = original # type: ignore[attr-defined] + return traced_invoke_stream + + def _wrap_converse_stream(self, original: Any) -> Any: + """Wrap ``converse_stream``.""" + adapter = self + + def traced_converse_stream(*args: Any, **kwargs: Any) -> Any: + model_id = kwargs.get("modelId", "") + start_ns = time.time_ns() + + input_messages = 
adapter._normalize_messages(kwargs.get("messages")) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"method": "converse_stream"}, + input_messages=input_messages, + ) + except Exception: + logger.warning( + "Error emitting Bedrock converse_stream error", exc_info=True + ) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + adapter._emit_model_invoke( + provider="aws_bedrock", + model=model_id, + latency_ms=elapsed_ms, + metadata={"method": "converse_stream", "streaming": True}, + input_messages=input_messages, + ) + except Exception: + logger.warning( + "Error emitting Bedrock converse_stream events", exc_info=True + ) + + return response + + traced_converse_stream._layerlens_original = original # type: ignore[attr-defined] + return traced_converse_stream + + # --- Message extraction --- + + @staticmethod + def _extract_invoke_messages( + kwargs: Dict[str, Any], + model_id: str, + ) -> Optional[List[Dict[str, str]]]: + """Extract messages from ``invoke_model`` body based on provider family.""" + try: + body = kwargs.get("body") + if not body: + return None + if isinstance(body, (str, bytes)): + body_data = json.loads(body) + elif isinstance(body, dict): + body_data = body + else: + return None + + family = _detect_provider_family(model_id) + messages: List[Dict[str, str]] = [] + + if family == "anthropic": + system = body_data.get("system", "") + if system: + messages.append({"role": "system", "content": str(system)[:10_000]}) + for msg in body_data.get("messages", []): + if isinstance(msg, dict) and "role" in msg: + content = msg.get("content", "") + if isinstance(content, list): + parts = [ + str(p.get("text", "")) + for p in content + if isinstance(p, dict) and "text" in p + ] + content = "\n".join(parts) + messages.append( + 
{"role": str(msg["role"]), "content": str(content)[:10_000]} + ) + elif family in ("meta", "mistral"): + prompt = body_data.get("prompt", "") + if prompt: + messages.append({"role": "user", "content": str(prompt)[:10_000]}) + else: + prompt = body_data.get("prompt") or body_data.get("inputText", "") + if prompt: + messages.append({"role": "user", "content": str(prompt)[:10_000]}) + + return messages if messages else None + except Exception: + logger.debug("Error extracting Bedrock invoke messages", exc_info=True) + return None + + # --- Output extraction --- + + @staticmethod + def _extract_invoke_output( + body_data: Dict[str, Any], + family: str, + ) -> Optional[Dict[str, str]]: + """Extract the output message from an ``invoke_model`` response body.""" + try: + if not body_data: + return None + + content = "" + if family == "anthropic": + content_blocks = body_data.get("content", []) + if content_blocks and isinstance(content_blocks, list): + parts = [] + for block in content_blocks: + if isinstance(block, dict) and "text" in block: + parts.append(str(block["text"])) + content = "\n".join(parts) + elif family in ("meta", "mistral"): + content = str(body_data.get("generation", "")) + elif family == "cohere": + generations = body_data.get("generations", []) + if generations and isinstance(generations, list): + content = str(generations[0].get("text", "")) + elif family == "amazon": + results = body_data.get("results", []) + if results and isinstance(results, list): + content = str(results[0].get("outputText", "")) + else: + content = str( + body_data.get("generation", "") + or body_data.get("completion", "") + or body_data.get("outputText", "") + ) + + if content: + return {"role": "assistant", "content": content[:10_000]} + return None + except Exception: + logger.debug("Error extracting Bedrock invoke output", exc_info=True) + return None + + @staticmethod + def _extract_converse_output(response: Dict[str, Any]) -> Optional[Dict[str, str]]: + """Extract the 
output message from a Converse API response.""" + try: + output = response.get("output", {}) + message = output.get("message", {}) + if not message: + return None + content_blocks = message.get("content", []) + if not content_blocks: + return None + parts: List[str] = [] + for block in content_blocks: + if isinstance(block, dict) and "text" in block: + parts.append(str(block["text"])) + if parts: + return {"role": "assistant", "content": "\n".join(parts)[:10_000]} + return None + except Exception: + logger.debug("Error extracting Bedrock converse output", exc_info=True) + return None + + # --- Token extraction --- + + @staticmethod + def _extract_invoke_usage( + body_data: Dict[str, Any], + family: str, + ) -> Optional[NormalizedTokenUsage]: + """Extract tokens from an ``invoke_model`` response body.""" + if not body_data: + return None + + if family == "anthropic": + usage = body_data.get("usage", {}) + input_t = usage.get("input_tokens", 0) + output_t = usage.get("output_tokens", 0) + return NormalizedTokenUsage( + prompt_tokens=input_t, + completion_tokens=output_t, + total_tokens=input_t + output_t, + ) + + if family == "meta": + prompt = body_data.get("prompt_token_count", 0) + completion = body_data.get("generation_token_count", 0) + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + + if family == "cohere": + meta = body_data.get("meta", {}) + tokens = meta.get("billed_units", {}) + prompt = tokens.get("input_tokens", 0) + completion = tokens.get("output_tokens", 0) + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + + prompt = body_data.get("inputTokenCount", 0) or body_data.get("prompt_tokens", 0) + completion = body_data.get("outputTokenCount", 0) or body_data.get( + "completion_tokens", 0 + ) + if prompt or completion: + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, 
+ total_tokens=prompt + completion, + ) + + return None + + @staticmethod + def _extract_converse_usage( + response: Dict[str, Any], + ) -> Optional[NormalizedTokenUsage]: + """Extract tokens from a Converse API response.""" + usage = response.get("usage", {}) + if not usage: + return None + prompt = usage.get("inputTokens", 0) + completion = usage.get("outputTokens", 0) + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + + +class _RereadableBody: + """Allows the Bedrock response body to be re-read after we consume it. + + Implements the subset of botocore's ``StreamingBody`` interface that + callers typically use after ``invoke_model``. + """ + + def __init__(self, data: bytes) -> None: + self._data = data + self._pos = 0 + + def read(self, amt: Optional[int] = None) -> bytes: + if amt is None: + self._pos = 0 + return self._data + result = self._data[self._pos : self._pos + amt] + self._pos += amt + return result + + def iter_chunks(self, chunk_size: int = 1024) -> Iterator[bytes]: + for i in range(0, len(self._data), chunk_size): + yield self._data[i : i + chunk_size] + + def iter_lines(self) -> Iterator[bytes]: + for line in self._data.split(b"\n"): + if line: + yield line + + def close(self) -> None: + pass + + @property + def content_length(self) -> int: + return len(self._data) + + +# Registry lazy-loading convention. +ADAPTER_CLASS = AWSBedrockAdapter diff --git a/src/layerlens/instrument/adapters/providers/cohere_adapter.py b/src/layerlens/instrument/adapters/providers/cohere_adapter.py new file mode 100644 index 0000000..16f265d --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/cohere_adapter.py @@ -0,0 +1,408 @@ +"""Cohere LLM Provider Adapter. + +Wraps the Cohere Python SDK (``cohere`` >= 5.x) to intercept ``chat`` +and ``embed`` calls. Emits ``model.invoke``, ``cost.record``, +``tool.call``, and ``policy.violation`` events. 
+ +This adapter is **fresh-built**, not a port — Cohere did not have an +adapter in ``ateam`` source as of 2026-04-25. It follows the same +contract as the OpenAI / Anthropic adapters: + +* Wraps ``client.chat`` (Cohere v1) and ``client.v2.chat`` (Cohere v2) + with method substitution. +* Wraps ``client.embed`` for embedding telemetry. +* Honors :class:`CaptureConfig` for layer gating. +* Restores originals on :meth:`disconnect`. + +Cohere's pricing tier is reused from the canonical +:data:`PRICING` table; Cohere-on-Bedrock uses :data:`BEDROCK_PRICING`. +For models not in either table the ``cost.record`` event sets +``api_cost_usd`` to ``None`` and ``pricing_unavailable`` to ``True``. +""" + +from __future__ import annotations + +import time +import logging +from typing import Any, Dict, List, Optional + +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter + +logger = logging.getLogger(__name__) + +# Parameters captured from request kwargs when present. +_CAPTURE_PARAMS = frozenset( + { + "model", + "temperature", + "max_tokens", + "p", + "k", + "top_p", + "top_k", + "frequency_penalty", + "presence_penalty", + "response_format", + "tool_choice", + } +) + + +class CohereAdapter(LLMProviderAdapter): + """LayerLens adapter for the Cohere Python SDK. 
def connect_client(self, client: Any) -> Any:
    """Instrument a Cohere client in place and return it.

    Replaces ``client.chat`` (traced as API version ``"v1"``),
    ``client.v2.chat`` (``"v2"``) and ``client.embed`` with tracing
    wrappers. Each replaced callable is kept in ``self._originals`` under
    the keys ``"chat"``, ``"v2.chat"`` and ``"embed"`` so that
    ``_restore_originals`` can put it back. Endpoints that are missing or
    not callable are skipped.

    Connecting the client that is already connected is idempotent: an
    endpoint that still carries this adapter's wrapper is left alone, so
    wrappers are never stacked (which would emit every event twice) and
    the stored original is never overwritten with a wrapper (which would
    make ``disconnect`` unable to restore the SDK method).

    Args:
        client: A Cohere SDK client instance.

    Returns:
        The same ``client`` object, now instrumented.
    """
    same_client = getattr(self, "_client", None) is client
    self._client = client

    def install(owner: Any, attr: str, key: str, wrap: Any) -> None:
        """Wrap ``owner.attr`` unless it is absent or already wrapped by us."""
        current = getattr(owner, attr, None)
        if not callable(current):
            return
        already_wrapped = (
            same_client
            and key in self._originals
            # The wrappers expose the callable they wrap; if it is the one we
            # stored, ``current`` is our own wrapper from an earlier connect.
            and getattr(current, "_layerlens_original", None) is self._originals[key]
        )
        if already_wrapped:
            return
        self._originals[key] = current
        setattr(owner, attr, wrap(current))

    # v1 chat (callable on the client directly).
    install(client, "chat", "chat", lambda fn: self._wrap_chat(fn, version="v1"))

    # v2 chat (Cohere SDK 5.x exposes ``client.v2.chat``).
    v2 = getattr(client, "v2", None)
    if v2 is not None:
        install(v2, "chat", "v2.chat", lambda fn: self._wrap_chat(fn, version="v2"))

    # Embed.
    install(client, "embed", "embed", self._wrap_embed)

    return client
+ input_messages: Optional[List[Dict[str, str]]] = None + if version == "v1": + msg = kwargs.get("message") + if msg: + input_messages = [{"role": "user", "content": str(msg)[:10_000]}] + preamble = kwargs.get("preamble") + if preamble: + if input_messages is None: + input_messages = [] + input_messages.insert( + 0, + {"role": "system", "content": str(preamble)[:10_000]}, + ) + else: + input_messages = adapter._normalize_messages(kwargs.get("messages")) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="cohere", + model=model, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + input_messages=input_messages, + ) + adapter._emit_provider_error("cohere", str(exc), model=model) + except Exception: + logger.warning("Error emitting Cohere error event", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response, version) + output_message = adapter._extract_output_message(response, version) + + metadata: Dict[str, Any] = {} + resp_id = getattr(response, "id", None) or getattr( + response, "generation_id", None + ) + if resp_id is not None: + metadata["response_id"] = resp_id + finish_reason = getattr(response, "finish_reason", None) + if finish_reason is not None: + metadata["finish_reason"] = str(finish_reason) + + adapter._emit_model_invoke( + provider="cohere", + model=model, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + input_messages=input_messages, + output_message=output_message, + metadata=metadata if metadata else None, + ) + adapter._emit_cost_record( + model=model, + usage=usage, + provider="cohere", + ) + + tool_calls = adapter._extract_tool_calls(response, version) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model) + except Exception: + logger.warning("Error emitting Cohere trace events", exc_info=True) + + 
return response + + traced_chat._layerlens_original = original # type: ignore[attr-defined] + return traced_chat + + def _wrap_embed(self, original: Any) -> Any: + adapter = self + + def traced_embed(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + start_ns = time.time_ns() + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="cohere", + model=model, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"request_type": "embedding"}, + ) + except Exception: + logger.warning("Error emitting Cohere embed error", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + # Cohere embed responses use ``meta.billed_units.input_tokens``. + meta = getattr(response, "meta", None) + billed = getattr(meta, "billed_units", None) if meta else None + if billed is None and isinstance(meta, dict): + billed = meta.get("billed_units") + input_tokens = 0 + if billed is not None: + input_tokens = ( + getattr(billed, "input_tokens", 0) + if not isinstance(billed, dict) + else billed.get("input_tokens", 0) + ) or 0 + usage = NormalizedTokenUsage( + prompt_tokens=input_tokens, + completion_tokens=0, + total_tokens=input_tokens, + ) + + adapter._emit_model_invoke( + provider="cohere", + model=model, + usage=usage, + latency_ms=elapsed_ms, + metadata={"request_type": "embedding"}, + ) + adapter._emit_cost_record( + model=model, + usage=usage, + provider="cohere", + ) + except Exception: + logger.warning("Error emitting Cohere embed events", exc_info=True) + + return response + + traced_embed._layerlens_original = original # type: ignore[attr-defined] + return traced_embed + + # --- Token + content extraction --- + + @staticmethod + def _extract_usage( + response: Any, + version: str, # noqa: ARG004 - kept for callsite symmetry; both versions use the same shape + ) -> Optional[NormalizedTokenUsage]: + """Extract token 
usage from a Cohere chat response. + + Both v1 and v2 expose ``response.meta.billed_units`` and / or + ``response.usage.tokens`` (varies by SDK version). + """ + meta = getattr(response, "meta", None) + if meta is not None: + billed = getattr(meta, "billed_units", None) + if billed is None and isinstance(meta, dict): + billed = meta.get("billed_units") + if billed is not None: + input_tokens = ( + getattr(billed, "input_tokens", 0) + if not isinstance(billed, dict) + else billed.get("input_tokens", 0) + ) or 0 + output_tokens = ( + getattr(billed, "output_tokens", 0) + if not isinstance(billed, dict) + else billed.get("output_tokens", 0) + ) or 0 + return NormalizedTokenUsage( + prompt_tokens=int(input_tokens), + completion_tokens=int(output_tokens), + total_tokens=int(input_tokens) + int(output_tokens), + ) + + # v2 sometimes exposes ``usage.tokens.input_tokens`` / ``output_tokens``. + usage = getattr(response, "usage", None) + if usage is not None: + tokens = getattr(usage, "tokens", None) + if tokens is not None: + input_tokens = getattr(tokens, "input_tokens", 0) or 0 + output_tokens = getattr(tokens, "output_tokens", 0) or 0 + return NormalizedTokenUsage( + prompt_tokens=int(input_tokens), + completion_tokens=int(output_tokens), + total_tokens=int(input_tokens) + int(output_tokens), + ) + + return None + + @staticmethod + def _extract_output_message( + response: Any, version: str + ) -> Optional[Dict[str, str]]: + """Extract the assistant output content.""" + try: + if version == "v1": + # v1 ``response.text`` contains the generated message. + text = getattr(response, "text", None) + if text: + return {"role": "assistant", "content": str(text)[:10_000]} + return None + + # v2: ``response.message.content`` is a list of content blocks. 
+ message = getattr(response, "message", None) + if message is None: + return None + content = getattr(message, "content", None) or [] + parts: List[str] = [] + for block in content: + btype = getattr(block, "type", None) + if btype == "text": + text = getattr(block, "text", "") + if text: + parts.append(str(text)) + if parts: + return {"role": "assistant", "content": "\n".join(parts)[:10_000]} + except Exception: + logger.debug("Error extracting Cohere output message", exc_info=True) + return None + + @staticmethod + def _extract_tool_calls(response: Any, version: str) -> List[Dict[str, Any]]: + """Extract tool calls (function invocations) from the response.""" + calls: List[Dict[str, Any]] = [] + try: + if version == "v1": + # v1: ``response.tool_calls`` is a list of {name, parameters}. + v1_calls = getattr(response, "tool_calls", None) or [] + for tc in v1_calls: + name = getattr(tc, "name", "unknown") + params = getattr(tc, "parameters", None) or {} + calls.append({"name": name, "arguments": params}) + return calls + + # v2: ``response.message.tool_calls`` of {id, function: {name, arguments}}. + message = getattr(response, "message", None) + if message is None: + return calls + v2_calls = getattr(message, "tool_calls", None) or [] + import json as _json + + for tc in v2_calls: + fn = getattr(tc, "function", None) + if fn is None: + continue + args_str = getattr(fn, "arguments", "{}") + try: + args = _json.loads(args_str) if isinstance(args_str, str) else args_str + except (ValueError, TypeError): + args = args_str + calls.append( + { + "name": getattr(fn, "name", "unknown"), + "arguments": args, + "id": getattr(tc, "id", None), + } + ) + except Exception: + logger.debug("Error extracting Cohere tool calls", exc_info=True) + return calls + + +# Registry lazy-loading convention. 
def connect_client(self, client: Any) -> Any:
    """Wrap a ``GenerativeModel`` instance with tracing.

    Replaces ``client.generate_content`` with the wrapper built by
    ``_wrap_generate_content`` and keeps the replaced callable in
    ``self._originals["generate_content"]`` so ``_restore_originals`` can
    put it back. A client without ``generate_content`` is returned
    untouched.

    Connecting the model that is already connected is idempotent: if its
    ``generate_content`` is still this adapter's wrapper, nothing is
    re-wrapped, so events are not duplicated and the stored original is
    not overwritten by a wrapper.

    Only one client is tracked at a time: connecting a different client
    replaces ``self._client`` and the stored original.

    Args:
        client: A ``google.generativeai.GenerativeModel`` or
            ``vertexai.generative_models.GenerativeModel`` instance.

    Returns:
        The same ``client`` object.
    """
    same_client = getattr(self, "_client", None) is client
    self._client = client

    if not hasattr(client, "generate_content"):
        return client

    current = client.generate_content
    already_wrapped = (
        same_client
        and "generate_content" in self._originals
        # The wrapper exposes the callable it wraps; if that is the one we
        # stored, ``current`` is our own wrapper from an earlier connect.
        and getattr(current, "_layerlens_original", None)
        is self._originals["generate_content"]
    )
    if not already_wrapped:
        self._originals["generate_content"] = current
        client.generate_content = self._wrap_generate_content(current)

    return client
elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="google_vertex", + model=model_name, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + input_messages=input_messages, + ) + adapter._emit_provider_error( + "google_vertex", str(exc), model=model_name + ) + except Exception: + logger.warning("Error emitting Vertex error event", exc_info=True) + raise + + if is_stream: + return adapter._wrap_stream( + response, model_name, params, start_ns, input_messages + ) + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + output_message = adapter._extract_output_text(response) + + metadata: Dict[str, Any] = {} + candidates = getattr(response, "candidates", None) or [] + if candidates: + fr = getattr(candidates[0], "finish_reason", None) + if fr is not None: + fr_name = getattr(fr, "name", None) + metadata["finish_reason"] = ( + fr_name if fr_name is not None else str(fr) + ) + + adapter._emit_model_invoke( + provider="google_vertex", + model=model_name, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + input_messages=input_messages, + output_message=output_message, + metadata=metadata if metadata else None, + ) + adapter._emit_cost_record( + model=model_name, usage=usage, provider="google_vertex" + ) + + tool_calls = adapter._extract_function_calls(response) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model_name) + except Exception: + logger.warning("Error emitting Vertex trace events", exc_info=True) + + return response + + traced_generate._layerlens_original = original # type: ignore[attr-defined] + return traced_generate + + def _wrap_stream( + self, + stream: Any, + model_name: Optional[str], + params: Dict[str, Any], + start_ns: int, + input_messages: Optional[List[Dict[str, str]]] = None, + ) -> Any: + adapter = self + final_usage: Optional[NormalizedTokenUsage] = None + stream_finish_reason: Optional[str] = None + 
+ class TracedStream: + def __init__(self, inner: Any) -> None: + self._inner = inner + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + nonlocal final_usage, stream_finish_reason + try: + chunk = next(self._inner) + except StopIteration: + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + stream_meta: Dict[str, Any] = {"streaming": True} + if stream_finish_reason is not None: + stream_meta["finish_reason"] = stream_finish_reason + adapter._emit_model_invoke( + provider="google_vertex", + model=model_name, + parameters=params, + usage=final_usage, + latency_ms=elapsed_ms, + metadata=stream_meta, + input_messages=input_messages, + ) + if final_usage: + adapter._emit_cost_record( + model=model_name, + usage=final_usage, + provider="google_vertex", + ) + except Exception: + logger.warning( + "Error emitting Vertex stream events", exc_info=True + ) + raise + + try: + chunk_usage = adapter._extract_usage(chunk) + if chunk_usage: + final_usage = chunk_usage + chunk_candidates = getattr(chunk, "candidates", None) or [] + if chunk_candidates: + fr = getattr(chunk_candidates[0], "finish_reason", None) + if fr is not None: + fr_name = getattr(fr, "name", None) + stream_finish_reason = ( + fr_name if fr_name is not None else str(fr) + ) + except Exception: + logger.debug("Error extracting Vertex stream usage", exc_info=True) + return chunk + + return TracedStream(stream) + + @staticmethod + def _extract_usage(response: Any) -> Optional[NormalizedTokenUsage]: + """Extract token usage from a Vertex response's ``usage_metadata``.""" + metadata = getattr(response, "usage_metadata", None) + if not metadata: + return None + prompt = getattr(metadata, "prompt_token_count", 0) or 0 + completion = getattr(metadata, "candidates_token_count", 0) or 0 + total = getattr(metadata, "total_token_count", 0) or (prompt + completion) + reasoning = getattr(metadata, "thoughts_token_count", None) + + return NormalizedTokenUsage( + 
prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=total, + reasoning_tokens=reasoning, + ) + + @staticmethod + def _normalize_vertex_contents(contents: Any) -> Optional[List[Dict[str, str]]]: + """Normalize Vertex AI contents to ``[{role, content}]``.""" + if contents is None: + return None + try: + messages: List[Dict[str, str]] = [] + if isinstance(contents, str): + messages.append({"role": "user", "content": contents[:10_000]}) + return messages + if isinstance(contents, list): + for item in contents: + if isinstance(item, str): + messages.append({"role": "user", "content": item[:10_000]}) + elif hasattr(item, "role") and hasattr(item, "parts"): + role = str(getattr(item, "role", "user")) + parts_text: List[str] = [] + for part in getattr(item, "parts", []): + text = getattr(part, "text", None) + if text: + parts_text.append(str(text)) + if parts_text: + messages.append( + {"role": role, "content": "\n".join(parts_text)[:10_000]} + ) + elif isinstance(item, dict): + role = str(item.get("role", "user")) + parts = item.get("parts", []) + parts_text2: List[str] = [] + for p in parts: + if isinstance(p, str): + parts_text2.append(p) + elif isinstance(p, dict) and "text" in p: + parts_text2.append(str(p["text"])) + if parts_text2: + messages.append( + { + "role": role, + "content": "\n".join(parts_text2)[:10_000], + } + ) + return messages if messages else None + except Exception: + logger.debug("Error normalizing Vertex contents", exc_info=True) + return None + + @staticmethod + def _extract_output_text(response: Any) -> Optional[Dict[str, str]]: + """Extract output text from a Vertex response.""" + try: + candidates = getattr(response, "candidates", None) or [] + if not candidates: + return None + content = getattr(candidates[0], "content", None) + if not content: + return None + parts = getattr(content, "parts", None) or [] + texts: List[str] = [] + for part in parts: + text = getattr(part, "text", None) + if text: + texts.append(str(text)) + if 
texts: + return {"role": "model", "content": "\n".join(texts)[:10_000]} + except Exception: + logger.debug("Error extracting Vertex output text", exc_info=True) + return None + + @staticmethod + def _extract_function_calls(response: Any) -> List[Dict[str, Any]]: + """Extract function calls from Vertex response candidates.""" + tool_calls: List[Dict[str, Any]] = [] + try: + candidates = getattr(response, "candidates", None) or [] + if not candidates: + return tool_calls + content = getattr(candidates[0], "content", None) + if not content: + return tool_calls + parts = getattr(content, "parts", None) or [] + for part in parts: + fn_call = getattr(part, "function_call", None) + if fn_call: + tool_calls.append( + { + "name": getattr(fn_call, "name", "unknown"), + "arguments": dict(getattr(fn_call, "args", {}) or {}), + } + ) + except Exception: + logger.debug("Error extracting Vertex function calls", exc_info=True) + return tool_calls + + +# Registry lazy-loading convention. +ADAPTER_CLASS = GoogleVertexAdapter diff --git a/src/layerlens/instrument/adapters/providers/litellm_adapter.py b/src/layerlens/instrument/adapters/providers/litellm_adapter.py new file mode 100644 index 0000000..af1e121 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/litellm_adapter.py @@ -0,0 +1,359 @@ +"""LiteLLM Provider Adapter. + +Uses the LiteLLM callback handler pattern (not monkey-patch). Registers +:class:`LayerLensLiteLLMCallback` via ``litellm.callbacks``. Auto-detects +the underlying provider from the model-string prefix. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/litellm_adapter.py``. 
# Model prefix → provider mapping.
_PROVIDER_PREFIXES: Dict[str, str] = {
    "openai/": "openai",
    "anthropic/": "anthropic",
    "azure/": "azure_openai",
    "bedrock/": "aws_bedrock",
    "vertex_ai/": "google_vertex",
    "ollama/": "ollama",
    "cohere/": "cohere",
    "huggingface/": "huggingface",
    "together_ai/": "together_ai",
    "groq/": "groq",
}


def detect_provider(model_str: str) -> str:
    """Map a LiteLLM model string to the name of its underlying provider.

    Resolution happens in two passes:

    1. An explicit routing prefix from ``_PROVIDER_PREFIXES`` (for example
       ``"azure/"``), matched case-sensitively, in table order.
    2. A well-known bare model-family stem (for example ``"gpt-"`` or
       ``"claude-"``), matched case-insensitively.

    Returns:
        The provider name, or ``"unknown"`` for an empty or unrecognized
        model string.
    """
    if not model_str:
        return "unknown"

    routed = next(
        (
            name
            for prefix, name in _PROVIDER_PREFIXES.items()
            if model_str.startswith(prefix)
        ),
        None,
    )
    if routed is not None:
        return routed

    # Bare model names carry no routing prefix; fall back to family stems.
    family_stems = (
        (("gpt-", "o1", "o3"), "openai"),
        (("claude-",), "anthropic"),
        (("gemini-",), "google_vertex"),
        (("llama",), "meta"),
        (("mistral",), "mistral"),
    )
    lowered = model_str.lower()
    for stems, name in family_stems:
        if lowered.startswith(stems):
            return name
    return "unknown"
def log_failure_event(
    self,
    kwargs: Dict[str, Any],
    response_obj: Any,  # noqa: ARG002 - LiteLLM callback signature requires this arg
    start_time: Any,
    end_time: Any,
) -> None:
    """Report a failed LiteLLM completion.

    Reads the model, the messages and the ``"exception"`` entry from the
    request ``kwargs``, then reports the failure through the adapter's
    ``_emit_model_invoke`` (with ``error`` set) and
    ``_emit_provider_error``. Any problem while reporting is logged as a
    warning and never propagates to LiteLLM.

    Args:
        kwargs: Request kwargs handed over by LiteLLM.
        response_obj: Unused; present to match the callback signature.
        start_time: Call start, as passed by LiteLLM.
        end_time: Call end, as passed by LiteLLM.
    """
    try:
        model_name = kwargs.get("model", "")
        provider_name = detect_provider(model_name)
        duration_ms = self._calc_latency_ms(start_time, end_time)
        failure = kwargs.get("exception", "")
        request_messages = self._adapter._normalize_messages(
            kwargs.get("messages")
        )

        sink = self._adapter
        sink._emit_model_invoke(
            provider=provider_name,
            model=model_name,
            parameters=self._extract_params(kwargs),
            latency_ms=duration_ms,
            error=str(failure),
            input_messages=request_messages,
        )
        sink._emit_provider_error(provider_name, str(failure), model=model_name)
    except Exception:
        logger.warning("Error in LiteLLM failure callback", exc_info=True)
callback", exc_info=True) + + @staticmethod + def _calc_latency_ms(start_time: Any, end_time: Any) -> Optional[float]: + if start_time is None or end_time is None: + return None + try: + if hasattr(start_time, "timestamp"): + return float((end_time.timestamp() - start_time.timestamp()) * 1000) + return float(end_time - start_time) * 1000 + except Exception: + return None + + @staticmethod + def _extract_usage(response_obj: Any) -> Optional[NormalizedTokenUsage]: + if response_obj is None: + return None + usage = getattr(response_obj, "usage", None) + if usage is None: + return None + prompt = getattr(usage, "prompt_tokens", 0) or 0 + completion = getattr(usage, "completion_tokens", 0) or 0 + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + + @staticmethod + def _extract_output_message(response_obj: Any) -> Optional[Dict[str, str]]: + """Extract output message from a LiteLLM response (OpenAI-compatible).""" + try: + if response_obj is None: + return None + choices = getattr(response_obj, "choices", None) or [] + if not choices: + return None + message = getattr(choices[0], "message", None) + if not message: + return None + content = getattr(message, "content", None) + if content: + return {"role": "assistant", "content": str(content)[:10_000]} + except Exception: + pass + return None + + @staticmethod + def _extract_params(kwargs: Dict[str, Any]) -> Dict[str, Any]: + params: Dict[str, Any] = {} + for key in ("temperature", "max_tokens", "top_p"): + if key in kwargs: + params[key] = kwargs[key] + opt = kwargs.get("optional_params", {}) + if isinstance(opt, dict): + for key in ("temperature", "max_tokens", "top_p"): + if key in opt and key not in params: + params[key] = opt[key] + return params + + @staticmethod + def _get_litellm_cost( + kwargs: Dict[str, Any], + response_obj: Any, + ) -> Optional[float]: + """Try to get cost from LiteLLM's built-in cost tracking.""" + try: + import litellm 
class LiteLLMAdapter(LLMProviderAdapter):
    """LayerLens adapter for LiteLLM.

    Uses LiteLLM's callback handler pattern instead of monkey-patching:
    ``connect`` appends a :class:`LayerLensLiteLLMCallback` to
    ``litellm.callbacks`` and ``disconnect`` removes it again. The
    underlying provider is auto-detected from the model-string prefix.
    """

    FRAMEWORK = "litellm"
    VERSION = "0.1.0"

    def __init__(
        self,
        stratix: Any = None,
        capture_config: Optional[CaptureConfig] = None,
    ) -> None:
        super().__init__(stratix=stratix, capture_config=capture_config)
        # The handler currently registered with LiteLLM, if any.
        self._callback: Optional[LayerLensLiteLLMCallback] = None

    def _unregister_callback(self) -> None:
        """Remove this adapter's handler from ``litellm.callbacks``, if present."""
        if self._callback is None:
            return
        try:
            import litellm  # type: ignore[import-not-found,unused-ignore]

            callbacks: Any = getattr(litellm, "callbacks", None)
            if callbacks is not None and self._callback in callbacks:
                callbacks.remove(self._callback)
        except ImportError:
            # LiteLLM is not installed, so nothing was ever registered.
            pass
        self._callback = None

    def connect(self) -> None:
        """Register the LayerLens callback with LiteLLM.

        Safe to call repeatedly: a handler registered by an earlier call is
        removed first, so exactly one handler per adapter is ever present
        in ``litellm.callbacks`` and events are not duplicated.

        When LiteLLM cannot be imported the adapter is still marked
        connected, but with status ``AdapterStatus.DEGRADED``.
        """
        # Drop any handler left over from a previous connect(); otherwise it
        # would stay registered with no reference left to remove it.
        self._unregister_callback()
        self._callback = LayerLensLiteLLMCallback(self)
        try:
            import litellm  # type: ignore[import-not-found,unused-ignore]

            # ``litellm.callbacks`` is typed as ``list[Callable]`` by upstream
            # stubs but accepts handler classes by convention. Cast through
            # ``Any`` at the boundary to satisfy strict type checkers.
            callbacks: Any = getattr(litellm, "callbacks", None)
            if callbacks is None:
                callbacks = []
                litellm.callbacks = callbacks
            callbacks.append(self._callback)
            version = getattr(litellm, "__version__", None)
            self._framework_version = str(version) if version is not None else None
            self._connected = True
            self._status = AdapterStatus.HEALTHY
        except ImportError:
            logger.warning("LiteLLM not installed; adapter in degraded mode")
            self._connected = True
            self._status = AdapterStatus.DEGRADED

    def disconnect(self) -> None:
        """Remove the LayerLens callback from LiteLLM."""
        self._unregister_callback()
        self._connected = False
        self._status = AdapterStatus.DISCONNECTED

    def connect_client(self, client: Any) -> Any:
        """LiteLLM uses callbacks, not client wrapping — no-op."""
        return client

    @staticmethod
    def _detect_framework_version() -> Optional[str]:
        """Return the installed LiteLLM version string, or ``None`` if unknown."""
        try:
            import litellm  # type: ignore[import-not-found,unused-ignore]

            version = getattr(litellm, "__version__", None)
            return str(version) if version is not None else None
        except ImportError:
            return None
+ +Wraps the Mistral Python SDK (``mistralai`` >= 1.x) to intercept +``client.chat.complete`` and ``client.chat.stream`` calls. Emits +``model.invoke``, ``cost.record``, ``tool.call``, and +``policy.violation`` events. + +Fresh-built (Mistral did not have an adapter in ``ateam`` source as of +2026-04-25). Follows the OpenAI / Anthropic adapter contract. +""" + +from __future__ import annotations + +import time +import logging +from typing import Any, Dict, List, Iterator, Optional + +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter + +logger = logging.getLogger(__name__) + +_CAPTURE_PARAMS = frozenset( + { + "model", + "temperature", + "max_tokens", + "top_p", + "random_seed", + "response_format", + "tool_choice", + "safe_prompt", + } +) + + +class MistralAdapter(LLMProviderAdapter): + """LayerLens adapter for the Mistral AI Python SDK. 
def connect_client(self, client: Any) -> Any:
    """Instrument a Mistral client in place and return it.

    Replaces ``client.chat.complete``, ``client.chat.stream`` and
    ``client.embeddings.create`` with tracing wrappers whenever the
    attribute exists and is callable. The untouched callables are stored
    in ``self._originals`` under the keys ``"chat.complete"``,
    ``"chat.stream"`` and ``"embeddings.create"``.

    Each endpoint is handled independently: a client without ``chat``
    still gets its embeddings endpoint instrumented.

    Args:
        client: The SDK client (or any object exposing the same attributes).

    Returns:
        The same ``client`` object, now instrumented.
    """
    self._client = client

    chat = getattr(client, "chat", None)
    if chat is not None:
        chat_targets = (
            ("complete", "chat.complete", self._wrap_complete),
            ("stream", "chat.stream", self._wrap_stream_method),
        )
        for attr_name, registry_key, make_wrapper in chat_targets:
            untraced = getattr(chat, attr_name, None)
            if callable(untraced):
                self._originals[registry_key] = untraced
                setattr(chat, attr_name, make_wrapper(untraced))

    # Embedding endpoint is at ``client.embeddings.create``. It is checked
    # regardless of whether ``chat`` exists, and only wrapped when callable,
    # matching the guard used for the chat methods.
    embeddings = getattr(client, "embeddings", None)
    untraced_embed = getattr(embeddings, "create", None)
    if callable(untraced_embed):
        self._originals["embeddings.create"] = untraced_embed
        embeddings.create = self._wrap_embed(untraced_embed)

    return client
adapter._emit_provider_error("mistral", str(exc), model=model) + except Exception: + logger.warning("Error emitting Mistral error event", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + output_message = adapter._extract_output_message(response) + + metadata: Dict[str, Any] = {} + resp_id = getattr(response, "id", None) + if resp_id is not None: + metadata["response_id"] = resp_id + resp_model = getattr(response, "model", None) + if resp_model is not None: + metadata["response_model"] = resp_model + choices = getattr(response, "choices", None) or [] + if choices: + finish = getattr(choices[0], "finish_reason", None) + if finish is not None: + metadata["finish_reason"] = str(finish) + + adapter._emit_model_invoke( + provider="mistral", + model=model, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + input_messages=input_messages, + output_message=output_message, + metadata=metadata if metadata else None, + ) + adapter._emit_cost_record( + model=model, + usage=usage, + provider="mistral", + ) + + tool_calls = adapter._extract_tool_calls(response) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model) + except Exception: + logger.warning("Error emitting Mistral trace events", exc_info=True) + + return response + + traced_complete._layerlens_original = original # type: ignore[attr-defined] + return traced_complete + + def _wrap_stream_method(self, original: Any) -> Any: + """Wrap ``client.chat.stream`` to emit one consolidated event on completion.""" + adapter = self + + def traced_stream(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + params = {k: kwargs[k] for k in _CAPTURE_PARAMS if k in kwargs} + start_ns = time.time_ns() + input_messages = adapter._normalize_messages(kwargs.get("messages")) + + try: + stream = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + 
def _wrap_embed(self, original: Any) -> Any:
    """Build a tracing wrapper around the embeddings ``create`` callable.

    The wrapper forwards all arguments to ``original`` unchanged. On
    success it emits a ``model.invoke`` event tagged
    ``request_type="embedding"`` followed by a cost record, then returns
    the response. When ``original`` raises, a ``model.invoke`` event
    carrying the error text is emitted and the exception is re-raised.
    Failures while emitting are logged and never reach the caller.

    The wrapped callable is exposed as ``_layerlens_original`` on the
    returned function.
    """
    adapter = self

    def traced_embed(*args: Any, **kwargs: Any) -> Any:
        requested_model = kwargs.get("model")
        began_ns = time.time_ns()

        def latency_ms() -> float:
            # Wall-clock time since the call started, in milliseconds.
            return (time.time_ns() - began_ns) / 1_000_000

        try:
            result = original(*args, **kwargs)
        except Exception as exc:
            took = latency_ms()
            try:
                adapter._emit_model_invoke(
                    provider="mistral",
                    model=requested_model,
                    latency_ms=took,
                    error=str(exc),
                    metadata={"request_type": "embedding"},
                )
            except Exception:
                logger.warning("Error emitting Mistral embed error", exc_info=True)
            raise

        # Telemetry is best effort: the response is returned regardless.
        try:
            took = latency_ms()
            token_usage = adapter._extract_usage(result)
            adapter._emit_model_invoke(
                provider="mistral",
                model=requested_model,
                usage=token_usage,
                latency_ms=took,
                metadata={"request_type": "embedding"},
            )
            adapter._emit_cost_record(
                model=requested_model,
                usage=token_usage,
                provider="mistral",
            )
        except Exception:
            logger.warning("Error emitting Mistral embed events", exc_info=True)

        return result

    traced_embed._layerlens_original = original  # type: ignore[attr-defined]
    return traced_embed
class _MistralTracedStream:
    """Proxy a Mistral chat stream and emit one consolidated ``model.invoke``.

    Each event pulled from the wrapped stream is inspected before being
    handed to the caller unchanged. The payload is ``event.data`` when
    that attribute exists, otherwise the event itself. From it the proxy
    accumulates ``choices[*].delta.content`` text fragments and remembers
    the most recent finish reason, response id and usage block.

    The consolidated event (plus a cost record when usage was seen) is
    emitted exactly once, when the underlying iterator is exhausted.
    Tool-call deltas are not accumulated or reported by this wrapper.

    The proxy can be used the same way as the raw stream: it is iterable,
    works as a context manager, exposes ``close()``, and forwards any other
    public attribute lookup to the wrapped stream.
    """

    def __init__(
        self,
        adapter: "MistralAdapter",
        inner: Any,
        model: Optional[str],
        params: Dict[str, Any],
        start_ns: int,
        input_messages: Optional[List[Dict[str, str]]],
    ) -> None:
        self._adapter = adapter
        # Keep the raw stream object: ``iter()`` may return a different
        # object, and close/exit must be forwarded to the real stream.
        self._stream = inner
        self._inner = iter(inner)
        self._model = model
        self._params = params
        self._start_ns = start_ns
        self._input_messages = input_messages
        self._content: List[str] = []
        self._final_usage: Optional[NormalizedTokenUsage] = None
        self._finish_reason: Optional[str] = None
        self._response_id: Optional[str] = None
        # Guards against re-emitting when next() is called after exhaustion.
        self._emitted = False

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes the proxy does not define itself.
        # Private names are never forwarded, which also avoids recursion
        # if ``_stream`` has not been set yet.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._stream, name)

    def __iter__(self) -> Iterator[Any]:
        return self

    def __next__(self) -> Any:
        try:
            event = next(self._inner)
        except StopIteration:
            self._emit_consolidated()
            raise
        try:
            self._absorb_event(event)
        except Exception:
            logger.debug("Error absorbing Mistral stream event", exc_info=True)
        return event

    def __enter__(self) -> Any:
        """Enter the wrapped stream's context (if it has one) and return the proxy."""
        enter_inner = getattr(self._stream, "__enter__", None)
        if enter_inner is not None:
            enter_inner()
        return self

    def __exit__(self, *exc_info: Any) -> Any:
        """Forward context exit to the wrapped stream, falling back to ``close()``."""
        exit_inner = getattr(self._stream, "__exit__", None)
        if exit_inner is not None:
            return exit_inner(*exc_info)
        self.close()
        return None

    def close(self) -> None:
        """Close the wrapped stream when it exposes a callable ``close``."""
        closer = getattr(self._stream, "close", None)
        if callable(closer):
            closer()

    def _absorb_event(self, event: Any) -> None:
        """Fold one stream event into the accumulated state."""
        data = getattr(event, "data", event)
        resp_id = getattr(data, "id", None)
        if resp_id is not None:
            self._response_id = str(resp_id)
        choices = getattr(data, "choices", None) or []
        for choice in choices:
            delta = getattr(choice, "delta", None)
            if delta is not None:
                content = getattr(delta, "content", None)
                if content:
                    self._content.append(str(content))
            finish = getattr(choice, "finish_reason", None)
            if finish is not None:
                self._finish_reason = str(finish)
        usage = getattr(data, "usage", None)
        if usage is not None:
            self._final_usage = MistralAdapter._extract_usage(data)

    def _emit_consolidated(self) -> None:
        """Emit the single ``model.invoke`` (and cost record) for this stream.

        Subsequent calls are no-ops. Emission errors are logged, not raised,
        so instrumentation can never break the caller's iteration.
        """
        if self._emitted:
            return
        # Mark first: a failure below must not cause a second attempt.
        self._emitted = True
        try:
            elapsed_ms = (time.time_ns() - self._start_ns) / 1_000_000
            output_message: Optional[Dict[str, str]] = None
            if self._content:
                output_message = {
                    "role": "assistant",
                    "content": "".join(self._content)[:10_000],
                }
            metadata: Dict[str, Any] = {"streaming": True}
            if self._finish_reason:
                metadata["finish_reason"] = self._finish_reason
            if self._response_id:
                metadata["response_id"] = self._response_id
            self._adapter._emit_model_invoke(
                provider="mistral",
                model=self._model,
                parameters=self._params,
                usage=self._final_usage,
                latency_ms=elapsed_ms,
                input_messages=self._input_messages,
                output_message=output_message,
                metadata=metadata,
            )
            if self._final_usage:
                self._adapter._emit_cost_record(
                    model=self._model,
                    usage=self._final_usage,
                    provider="mistral",
                )
        except Exception:
            logger.warning("Error emitting Mistral stream events", exc_info=True)
def connect_client(self, client: Any) -> Any:
    """Instrument an Ollama client (or the ``ollama`` module) in place.

    For each of ``chat``, ``generate`` and ``embeddings`` that the client
    exposes, the current attribute is saved in ``self._originals`` under
    the method name and replaced with the tracing wrapper produced by
    ``self._wrap_call``. Attributes the client lacks are skipped.

    Returns:
        The same ``client`` object.
    """
    self._client = client

    for method_name in ("chat", "generate", "embeddings"):
        if not hasattr(client, method_name):
            continue
        untraced = getattr(client, method_name)
        self._originals[method_name] = untraced
        setattr(client, method_name, self._wrap_call(untraced, method_name))

    return client
method_name: str) -> Any: + adapter = self + + def traced_call(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") or (args[0] if args else None) + start_ns = time.time_ns() + + input_messages = None + if method_name == "chat": + input_messages = adapter._normalize_messages(kwargs.get("messages")) + elif method_name == "generate": + prompt = kwargs.get("prompt") + if prompt: + input_messages = [ + {"role": "user", "content": str(prompt)[:10_000]} + ] + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="ollama", + model=model, + latency_ms=elapsed_ms, + error=str(exc), + metadata={ + "method": method_name, + "endpoint": adapter._endpoint, + }, + input_messages=input_messages, + ) + adapter._emit_provider_error("ollama", str(exc), model=model) + except Exception: + logger.warning("Error emitting Ollama error event", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + infra_cost = adapter._calculate_infra_cost(response) + output_message = adapter._extract_output_message(response, method_name) + + ollama_metadata: Dict[str, Any] = { + "method": method_name, + "endpoint": adapter._endpoint, + } + if isinstance(response, dict): + done_reason = response.get("done_reason") + else: + done_reason = getattr(response, "done_reason", None) + if done_reason is not None: + ollama_metadata["finish_reason"] = done_reason + + adapter._emit_model_invoke( + provider="ollama", + model=model, + usage=usage, + latency_ms=elapsed_ms, + metadata=ollama_metadata, + input_messages=input_messages, + output_message=output_message, + ) + + cost_meta: Dict[str, Any] = {"api_cost_usd": 0.0} + if infra_cost is not None: + cost_meta["infra_cost_usd"] = infra_cost + + adapter.emit_dict_event( + "cost.record", + { + "provider": "ollama", + "model": model, + "prompt_tokens": 
usage.prompt_tokens if usage else 0, + "completion_tokens": usage.completion_tokens if usage else 0, + "total_tokens": usage.total_tokens if usage else 0, + **cost_meta, + }, + ) + except Exception: + logger.warning("Error emitting Ollama trace events", exc_info=True) + + return response + + traced_call._layerlens_original = original # type: ignore[attr-defined] + return traced_call + + @staticmethod + def _extract_usage(response: Any) -> Optional[NormalizedTokenUsage]: + """Extract token usage from an Ollama response.""" + if response is None: + return None + if isinstance(response, dict): + prompt = response.get("prompt_eval_count", 0) or 0 + completion = response.get("eval_count", 0) or 0 + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + prompt = getattr(response, "prompt_eval_count", 0) or 0 + completion = getattr(response, "eval_count", 0) or 0 + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=prompt + completion, + ) + + @staticmethod + def _extract_output_message( + response: Any, method_name: str + ) -> Optional[Dict[str, str]]: + """Extract the output message from an Ollama response.""" + try: + if response is None: + return None + if method_name == "chat": + msg = ( + response.get("message", {}) + if isinstance(response, dict) + else getattr(response, "message", None) + ) + if msg: + content = ( + msg.get("content", "") + if isinstance(msg, dict) + else getattr(msg, "content", "") + ) + if content: + return {"role": "assistant", "content": str(content)[:10_000]} + elif method_name == "generate": + text = ( + response.get("response", "") + if isinstance(response, dict) + else getattr(response, "response", "") + ) + if text: + return {"role": "assistant", "content": str(text)[:10_000]} + except Exception: + pass + return None + + def _calculate_infra_cost(self, response: Any) -> Optional[float]: + """Calculate optional 
infrastructure cost from compute duration.""" + if self._cost_per_second is None: + return None + if response is None: + return None + + total_ns = 0 + if isinstance(response, dict): + total_ns = (response.get("eval_duration", 0) or 0) + ( + response.get("prompt_eval_duration", 0) or 0 + ) + else: + total_ns = (getattr(response, "eval_duration", 0) or 0) + ( + getattr(response, "prompt_eval_duration", 0) or 0 + ) + + if total_ns > 0: + total_seconds = total_ns / 1_000_000_000 + return round(total_seconds * self._cost_per_second, 8) + return None + + +# Registry lazy-loading convention. +ADAPTER_CLASS = OllamaAdapter diff --git a/src/layerlens/instrument/adapters/providers/openai_adapter.py b/src/layerlens/instrument/adapters/providers/openai_adapter.py new file mode 100644 index 0000000..1e8be31 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/openai_adapter.py @@ -0,0 +1,467 @@ +"""OpenAI LLM Provider Adapter. + +Wraps the OpenAI Python SDK client to intercept chat completions, +embeddings, and streaming calls. Emits ``model.invoke``, +``cost.record``, ``tool.call``, and ``policy.violation`` events. + +Ported from ``ateam/stratix/sdk/python/adapters/llm_providers/openai_adapter.py``. +""" + +from __future__ import annotations + +import json +import time +import logging +from typing import Any, Dict, List, Iterator, Optional + +from layerlens.instrument.adapters._base.capture import CaptureConfig +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers._base.provider import LLMProviderAdapter + +logger = logging.getLogger(__name__) + +# Parameters to capture from request kwargs. +_CAPTURE_PARAMS = frozenset( + { + "model", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "response_format", + "tool_choice", + } +) + + +class OpenAIAdapter(LLMProviderAdapter): + """LayerLens adapter for the OpenAI Python SDK. 
def connect_client(self, client: Any) -> Any:
    """Instrument an OpenAI client in place and return it.

    When the client exposes ``chat.completions``, its ``create`` callable
    is saved under ``"chat.completions.create"`` in ``self._originals``
    and replaced by the chat tracing wrapper. When the client exposes
    ``embeddings``, its ``create`` callable is saved under
    ``"embeddings.create"`` and replaced by the embeddings wrapper.
    """
    self._client = client
    saved = self._originals

    supports_chat = hasattr(client, "chat") and hasattr(client.chat, "completions")
    if supports_chat:
        completions = client.chat.completions
        untraced_create = completions.create
        saved["chat.completions.create"] = untraced_create
        completions.create = self._wrap_chat_create(untraced_create)

    if hasattr(client, "embeddings"):
        embeddings = client.embeddings
        untraced_embed = embeddings.create
        saved["embeddings.create"] = untraced_embed
        embeddings.create = self._wrap_embeddings_create(untraced_embed)

    return client
str(version) if version is not None else None + except ImportError: + return None + + # --- Wrapping methods --- + + def _wrap_chat_create(self, original: Any) -> Any: + adapter = self + + def traced_create(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + params = {k: kwargs[k] for k in _CAPTURE_PARAMS if k in kwargs} + is_stream = kwargs.get("stream", False) + start_ns = time.time_ns() + + input_messages = adapter._normalize_messages(kwargs.get("messages")) + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="openai", + model=model, + parameters=params, + latency_ms=elapsed_ms, + error=str(exc), + input_messages=input_messages, + ) + adapter._emit_provider_error("openai", str(exc), model=model) + except Exception: + logger.warning("Error emitting OpenAI error event", exc_info=True) + raise + + if is_stream: + return adapter._wrap_stream(response, model, params, start_ns, input_messages) + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + output_message = adapter._extract_output_message(response) + + metadata: Dict[str, Any] = {} + choices = getattr(response, "choices", None) or [] + if choices: + fr = getattr(choices[0], "finish_reason", None) + if fr is not None: + metadata["finish_reason"] = fr + resp_id = getattr(response, "id", None) + if resp_id is not None: + metadata["response_id"] = resp_id + resp_model = getattr(response, "model", None) + if resp_model is not None: + metadata["response_model"] = resp_model + sys_fp = getattr(response, "system_fingerprint", None) + if sys_fp is not None: + metadata["system_fingerprint"] = sys_fp + svc_tier = getattr(response, "service_tier", None) + if svc_tier is not None: + metadata["service_tier"] = svc_tier + seed = kwargs.get("seed") + if seed is not None: + metadata["seed"] = seed + + adapter._emit_model_invoke( + 
provider="openai", + model=model, + parameters=params, + usage=usage, + latency_ms=elapsed_ms, + input_messages=input_messages, + output_message=output_message, + metadata=metadata if metadata else None, + ) + adapter._emit_cost_record(model=model, usage=usage, provider="openai") + + tool_calls = adapter._extract_tool_calls(response) + if tool_calls: + adapter._emit_tool_calls(tool_calls, parent_model=model) + except Exception: + logger.warning("Error emitting OpenAI trace events", exc_info=True) + + return response + + traced_create._layerlens_original = original # type: ignore[attr-defined] + return traced_create + + def _wrap_stream( + self, + stream: Any, + model: Optional[str], + params: Dict[str, Any], + start_ns: int, + input_messages: Optional[List[Dict[str, str]]] = None, + ) -> Any: + """Wrap a streaming response to accumulate chunks and emit on completion.""" + adapter = self + accumulated_content: List[str] = [] + accumulated_tool_calls: Dict[int, Dict[str, Any]] = {} + final_usage: Optional[NormalizedTokenUsage] = None + stream_finish_reason: Optional[str] = None + stream_response_id: Optional[str] = None + stream_response_model: Optional[str] = None + stream_system_fingerprint: Optional[str] = None + + class TracedStream: + """Wrapper that intercepts stream iteration.""" + + def __init__(self, inner: Any) -> None: + self._inner = inner + + def __iter__(self) -> Iterator[Any]: + return self + + def __next__(self) -> Any: + try: + chunk = next(self._inner) + except StopIteration: + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + output_msg: Optional[Dict[str, str]] = None + if accumulated_content: + output_msg = { + "role": "assistant", + "content": "".join(accumulated_content)[:10_000], + } + stream_meta: Dict[str, Any] = {"streaming": True} + if stream_finish_reason is not None: + stream_meta["finish_reason"] = stream_finish_reason + if stream_response_id is not None: + stream_meta["response_id"] = stream_response_id + if 
stream_response_model is not None: + stream_meta["response_model"] = stream_response_model + if stream_system_fingerprint is not None: + stream_meta["system_fingerprint"] = stream_system_fingerprint + adapter._emit_model_invoke( + provider="openai", + model=model, + parameters=params, + usage=final_usage, + latency_ms=elapsed_ms, + metadata=stream_meta, + input_messages=input_messages, + output_message=output_msg, + ) + if final_usage: + adapter._emit_cost_record( + model=model, + usage=final_usage, + provider="openai", + ) + if accumulated_tool_calls: + tcs = [ + { + "name": tc.get("name", ""), + "arguments": tc.get("arguments", ""), + "id": tc.get("id"), + } + for tc in accumulated_tool_calls.values() + ] + adapter._emit_tool_calls(tcs, parent_model=model) + except Exception: + logger.warning("Error emitting OpenAI stream events", exc_info=True) + raise + + try: + self._process_chunk(chunk) + except Exception: + logger.debug("Error processing OpenAI stream chunk", exc_info=True) + return chunk + + def _process_chunk(self, chunk: Any) -> None: + nonlocal final_usage, stream_finish_reason, stream_response_id + nonlocal stream_response_model, stream_system_fingerprint + chunk_id = getattr(chunk, "id", None) + if chunk_id is not None: + stream_response_id = chunk_id + chunk_model = getattr(chunk, "model", None) + if chunk_model is not None: + stream_response_model = chunk_model + chunk_fp = getattr(chunk, "system_fingerprint", None) + if chunk_fp is not None: + stream_system_fingerprint = chunk_fp + choices = getattr(chunk, "choices", None) or [] + for choice in choices: + fr = getattr(choice, "finish_reason", None) + if fr is not None: + stream_finish_reason = fr + delta = getattr(choice, "delta", None) + if delta: + content = getattr(delta, "content", None) + if content: + accumulated_content.append(content) + tc_deltas = getattr(delta, "tool_calls", None) or [] + for tc_delta in tc_deltas: + idx = getattr(tc_delta, "index", 0) + if idx not in 
accumulated_tool_calls: + accumulated_tool_calls[idx] = { + "id": getattr(tc_delta, "id", None), + "name": "", + "arguments": "", + } + fn = getattr(tc_delta, "function", None) + if fn: + name = getattr(fn, "name", None) + if name: + accumulated_tool_calls[idx]["name"] = name + args = getattr(fn, "arguments", None) + if args: + accumulated_tool_calls[idx]["arguments"] += args + tc_id = getattr(tc_delta, "id", None) + if tc_id: + accumulated_tool_calls[idx]["id"] = tc_id + + usage = getattr(chunk, "usage", None) + if usage: + final_usage = adapter._extract_usage_from_obj(usage) + + def __enter__(self) -> Any: + return self + + def __exit__(self, *args: Any) -> Any: + if hasattr(self._inner, "__exit__"): + return self._inner.__exit__(*args) + if hasattr(self._inner, "close"): + self._inner.close() + return None + + def close(self) -> None: + if hasattr(self._inner, "close"): + self._inner.close() + + return TracedStream(stream) + + def _wrap_embeddings_create(self, original: Any) -> Any: + adapter = self + + def traced_embed(*args: Any, **kwargs: Any) -> Any: + model = kwargs.get("model") + start_ns = time.time_ns() + + try: + response = original(*args, **kwargs) + except Exception as exc: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + try: + adapter._emit_model_invoke( + provider="openai", + model=model, + latency_ms=elapsed_ms, + error=str(exc), + metadata={"request_type": "embedding"}, + ) + except Exception: + logger.warning("Error emitting OpenAI embedding error", exc_info=True) + raise + + try: + elapsed_ms = (time.time_ns() - start_ns) / 1_000_000 + usage = adapter._extract_usage(response) + adapter._emit_model_invoke( + provider="openai", + model=model, + usage=usage, + latency_ms=elapsed_ms, + metadata={"request_type": "embedding"}, + ) + adapter._emit_cost_record(model=model, usage=usage, provider="openai") + except Exception: + logger.warning("Error emitting OpenAI embedding events", exc_info=True) + + return response + + 
traced_embed._layerlens_original = original # type: ignore[attr-defined] + return traced_embed + + # --- Token extraction --- + + def _extract_usage(self, response: Any) -> Optional[NormalizedTokenUsage]: + """Extract token usage from a synchronous OpenAI response.""" + usage = getattr(response, "usage", None) + if not usage: + return None + return self._extract_usage_from_obj(usage) + + @staticmethod + def _extract_usage_from_obj(usage: Any) -> NormalizedTokenUsage: + """Extract :class:`NormalizedTokenUsage` from an OpenAI Usage object.""" + prompt = getattr(usage, "prompt_tokens", 0) or 0 + completion = getattr(usage, "completion_tokens", 0) or 0 + total = getattr(usage, "total_tokens", 0) or (prompt + completion) + + cached: Optional[int] = None + details = getattr(usage, "prompt_tokens_details", None) + if details: + cached = getattr(details, "cached_tokens", None) + + reasoning: Optional[int] = None + comp_details = getattr(usage, "completion_tokens_details", None) + if comp_details: + reasoning = getattr(comp_details, "reasoning_tokens", None) + + return NormalizedTokenUsage( + prompt_tokens=prompt, + completion_tokens=completion, + total_tokens=total, + cached_tokens=cached, + reasoning_tokens=reasoning, + ) + + @staticmethod + def _extract_output_message(response: Any) -> Optional[Dict[str, str]]: + """Extract the assistant output message from an OpenAI response.""" + try: + choices = getattr(response, "choices", None) or [] + if not choices: + return None + message = getattr(choices[0], "message", None) + if not message: + return None + content = getattr(message, "content", None) + if content: + return {"role": "assistant", "content": str(content)[:10_000]} + except Exception: + logger.debug("Error extracting OpenAI output message", exc_info=True) + return None + + @staticmethod + def _extract_tool_calls(response: Any) -> List[Dict[str, Any]]: + """Extract tool calls from an OpenAI response.""" + tool_calls: List[Dict[str, Any]] = [] + try: + choices = 
getattr(response, "choices", None) or [] + if not choices: + return tool_calls + message = getattr(choices[0], "message", None) + if not message: + return tool_calls + tcs = getattr(message, "tool_calls", None) or [] + for tc in tcs: + fn = getattr(tc, "function", None) + if fn: + args_str = getattr(fn, "arguments", "{}") + try: + args = json.loads(args_str) + except (json.JSONDecodeError, TypeError): + args = args_str + tool_calls.append( + { + "name": getattr(fn, "name", "unknown"), + "arguments": args, + "id": getattr(tc, "id", None), + } + ) + except Exception: + logger.debug("Error extracting OpenAI tool calls", exc_info=True) + return tool_calls + + +# Registry lazy-loading convention. +ADAPTER_CLASS = OpenAIAdapter diff --git a/tests/instrument/adapters/__init__.py b/tests/instrument/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/providers/__init__.py b/tests/instrument/adapters/providers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/adapters/providers/test_anthropic_adapter.py b/tests/instrument/adapters/providers/test_anthropic_adapter.py new file mode 100644 index 0000000..6c4f487 --- /dev/null +++ b/tests/instrument/adapters/providers/test_anthropic_adapter.py @@ -0,0 +1,385 @@ +"""Unit tests for the Anthropic provider adapter. + +Mocked at the SDK-response shape level. Verifies that the adapter wraps +``client.messages.create`` and ``client.messages.stream`` correctly, +emits the expected events, and restores the original methods on +disconnect. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List +from unittest import mock + +import pytest + +from layerlens.instrument.adapters._base import ( + AdapterStatus, + CaptureConfig, + AdapterCapability, +) +from layerlens.instrument.adapters.providers.anthropic_adapter import ( + ADAPTER_CLASS, + AnthropicAdapter, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + elif len(args) == 1: + self.events.append({"event_type": None, "payload": args[0]}) + + +def _make_response( + *, + text: str = "hello", + input_tokens: int = 10, + output_tokens: int = 5, + stop_reason: str = "end_turn", + response_id: str = "msg-abc", + response_model: str = "claude-sonnet-4-5-20250929", + tool_uses: List[Any] = None, + cache_creation: int = None, + cache_read: int = None, +) -> Any: + """Build an object that quacks like an Anthropic Message.""" + content_blocks = [SimpleNamespace(type="text", text=text)] + if tool_uses: + content_blocks.extend(tool_uses) + + usage_kwargs: Dict[str, Any] = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + if cache_creation is not None: + usage_kwargs["cache_creation_input_tokens"] = cache_creation + if cache_read is not None: + usage_kwargs["cache_read_input_tokens"] = cache_read + + return SimpleNamespace( + id=response_id, + model=response_model, + content=content_blocks, + stop_reason=stop_reason, + usage=SimpleNamespace(**usage_kwargs), + ) + + +def _make_client(*, returns: Any = None, raises: Exception = None) -> Any: + def _create(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return returns + + def _stream(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return mock.MagicMock() + + messages = 
class TestAnthropicAdapterLifecycle:
    """Static metadata and connect/disconnect state of ``AnthropicAdapter``."""

    def test_adapter_class_export(self) -> None:
        """The module-level ``ADAPTER_CLASS`` alias is the adapter class."""
        assert ADAPTER_CLASS is AnthropicAdapter

    def test_framework_and_version(self) -> None:
        """Framework id and version constants carry the expected literals."""
        subject = AnthropicAdapter()
        assert subject.FRAMEWORK == "anthropic"
        assert subject.VERSION == "0.1.0"

    def test_connect_disconnect(self) -> None:
        """``connect`` reports healthy; ``disconnect`` reports disconnected."""
        subject = AnthropicAdapter()

        subject.connect()
        assert subject.is_connected is True
        assert subject.status == AdapterStatus.HEALTHY

        subject.disconnect()
        assert subject.is_connected is False
        assert subject.status == AdapterStatus.DISCONNECTED

    def test_get_adapter_info(self) -> None:
        """Adapter info exposes framework, name and model tracing."""
        described = AnthropicAdapter().get_adapter_info()
        assert described.framework == "anthropic"
        assert described.name == "AnthropicAdapter"
        assert AdapterCapability.TRACE_MODELS in described.capabilities
def test_system_prompt_recorded_as_has_system(self) -> None:
    """A ``system=`` argument sets ``has_system`` and leads the messages."""
    recorder = _RecordingStratix()
    subject = AnthropicAdapter(stratix=recorder, capture_config=CaptureConfig.full())
    subject.connect()
    fake_client = _make_client(returns=_make_response())
    subject.connect_client(fake_client)

    request = {
        "model": "claude-sonnet-4-5-20250929",
        "max_tokens": 100,
        "system": "You are concise.",
        "messages": [{"role": "user", "content": "hi"}],
    }
    fake_client.messages.create(**request)

    payload = next(
        event["payload"]
        for event in recorder.events
        if event["event_type"] == "model.invoke"
    )
    assert payload["parameters"].get("has_system") is True
    # The system prompt is expected as the first normalized message.
    normalized = payload.get("messages")
    assert normalized is not None
    assert normalized[0]["role"] == "system"
def test_provider_error_emits_policy_violation(self) -> None:
    """An SDK exception propagates and is still reported as telemetry."""
    recorder = _RecordingStratix()
    subject = AnthropicAdapter(stratix=recorder, capture_config=CaptureConfig.full())
    subject.connect()
    failing_client = _make_client(raises=RuntimeError("rate limited"))
    subject.connect_client(failing_client)

    request = {
        "model": "claude-sonnet-4-5-20250929",
        "max_tokens": 100,
        "messages": [{"role": "user", "content": "x"}],
    }
    with pytest.raises(RuntimeError):
        failing_client.messages.create(**request)

    # Both the invocation record and the violation must have been emitted.
    emitted = [event["event_type"] for event in recorder.events]
    for expected in ("model.invoke", "policy.violation"):
        assert expected in emitted
class TestAnthropicStreaming:
    """Stream wrapping: several wire events collapse into one ``model.invoke``."""

    def test_stream_emits_one_consolidated_invoke(self) -> None:
        """Iterating a streamed response must emit exactly one model.invoke."""
        recorder = _RecordingStratix()
        subject = AnthropicAdapter(stratix=recorder, capture_config=CaptureConfig.full())
        subject.connect()

        # Synthetic events shaped like Anthropic's streaming wire format:
        # message_start -> content_block_delta -> message_delta.
        wire_events = [
            SimpleNamespace(
                type="message_start",
                message=SimpleNamespace(
                    id="msg-1",
                    model="claude-sonnet-4-5-20250929",
                    usage=SimpleNamespace(input_tokens=10, output_tokens=0),
                ),
            ),
            SimpleNamespace(
                type="content_block_delta",
                delta=SimpleNamespace(type="text_delta", text="hello"),
            ),
            SimpleNamespace(
                type="message_delta",
                delta=SimpleNamespace(stop_reason="end_turn"),
                usage=SimpleNamespace(output_tokens=5),
            ),
        ]

        # _wrap_stream_response is driven directly (no connect_client)
        # because the public streaming API is a context manager.
        wrapped = subject._wrap_stream_response(
            iter(wire_events),
            model="claude-sonnet-4-5-20250929",
            params={},
            start_ns=0,
        )
        for _ in wrapped:
            pass

        invoke_events = [
            event for event in recorder.events if event["event_type"] == "model.invoke"
        ]
        assert len(invoke_events) == 1
        payload = invoke_events[0]["payload"]
        assert payload.get("streaming") is True
        assert payload["finish_reason"] == "end_turn"
        # The accumulated text must surface as an output message.
        assert payload.get("output_message") is not None
+""" + +from __future__ import annotations + +import os +import json +import time +import threading +from typing import Any, Dict, List, Tuple +from http.server import HTTPServer, BaseHTTPRequestHandler + +import pytest + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.providers.anthropic_adapter import AnthropicAdapter + +pytestmark = [ + pytest.mark.live, + pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY"), + reason="ANTHROPIC_API_KEY not set; skipping live Anthropic tests", + ), +] + + +@pytest.fixture +def live_anthropic_client() -> Any: + try: + from anthropic import Anthropic + except ImportError: + pytest.skip("anthropic package not installed") + return Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"]) + + +class _IngestRecorder: + def __init__(self) -> None: + self.batches: List[Dict[str, Any]] = [] + self.lock = threading.Lock() + + +def _make_ingest_handler(recorder: _IngestRecorder) -> type: + class _Handler(BaseHTTPRequestHandler): + def log_message(self, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + pass + + def do_POST(self) -> None: # noqa: N802 + length = int(self.headers.get("Content-Length", "0")) + raw = self.rfile.read(length) if length > 0 else b"" + try: + body = json.loads(raw) + except json.JSONDecodeError: + body = {"_raw": raw.decode("utf-8", "replace")} + with recorder.lock: + recorder.batches.append(body) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"ok":true}') + + return _Handler + + +@pytest.fixture +def ingest_server() -> Any: + recorder = _IngestRecorder() + httpd = HTTPServer(("127.0.0.1", 0), _make_ingest_handler(recorder)) + port = httpd.server_address[1] + thread = threading.Thread(target=httpd.serve_forever, daemon=True) + thread.start() + try: + yield f"http://127.0.0.1:{port}", recorder + finally: + 
class TestAnthropicAdapterLive:
    """Live round trip: real Anthropic call -> adapter -> local ingest server."""

    def test_real_messages_create_emits_full_event_set(
        self,
        live_anthropic_client: Any,
        ingest_server: Tuple[str, _IngestRecorder],
    ) -> None:
        """One real ``messages.create`` yields model.invoke and cost.record."""
        base_url, recorder = ingest_server

        # max_batch=1 with a zero flush interval is meant to ship each event
        # on its own.
        sink = HttpEventSink(
            adapter_name="anthropic",
            api_key="test-org-key",
            base_url=base_url,
            path="/telemetry/spans",
            max_batch=1,
            flush_interval_s=0.0,
        )
        adapter = AnthropicAdapter(capture_config=CaptureConfig.standard())
        adapter.add_sink(sink)
        adapter.connect()
        adapter.connect_client(live_anthropic_client)

        try:
            response = live_anthropic_client.messages.create(
                model="claude-haiku-4-5-20251001",
                max_tokens=10,
                messages=[{"role": "user", "content": "Say hi in one word."}],
            )
        finally:
            sink.close()
            adapter.disconnect()

        assert response.content
        assert response.usage is not None

        # Give the ingest server a moment to finish handling the posts.
        time.sleep(0.5)
        with recorder.lock:
            batches = list(recorder.batches)
        assert batches, "no events reached the ingest server"

        all_events: List[Dict[str, Any]] = [
            event for batch in batches for event in batch.get("events", [])
        ]
        seen_types = [event["event_type"] for event in all_events]
        assert "model.invoke" in seen_types
        assert "cost.record" in seen_types

        invoke_payload = next(
            event for event in all_events if event["event_type"] == "model.invoke"
        )["payload"]
        # Compared against the real SDK usage object, so a rename of
        # `usage.input_tokens` in the Anthropic SDK would fail here.
        assert invoke_payload["prompt_tokens"] == response.usage.input_tokens
        assert invoke_payload["completion_tokens"] == response.usage.output_tokens
        assert invoke_payload["latency_ms"] > 0

        cost_payload = next(
            event for event in all_events if event["event_type"] == "cost.record"
        )["payload"]
        assert cost_payload["api_cost_usd"] is not None
        assert cost_payload["api_cost_usd"] >= 0
def test_endpoint_sanitization_strips_query_string() -> None:
    """The sanitized endpoint keeps the host but none of the query secret."""
    endpoint_with_key = "https://my-resource.openai.azure.com/path/?api-key=SECRET"

    cleaned = AzureOpenAIAdapter._sanitize_endpoint(endpoint_with_key)

    assert cleaned is not None
    # The credential from the query string must be gone ...
    assert "SECRET" not in cleaned
    # ... while the resource host is still identifiable.
    assert "my-resource.openai.azure.com" in cleaned
def test_azure_metadata_in_payload() -> None:
    """``model.invoke`` carries the API version and a query-free endpoint."""
    recorder = _RecordingStratix()
    subject = AzureOpenAIAdapter(stratix=recorder, capture_config=CaptureConfig.full())
    subject.connect()
    fake_client = _make_client(returns=_make_response())
    subject.connect_client(fake_client)

    fake_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": "x"}],
    )

    payload = next(
        event for event in recorder.events if event["event_type"] == "model.invoke"
    )["payload"]
    assert payload["api_version"] == "2024-08-01"
    # The fake client's base URL has an api-key query; it must not survive.
    assert "api-key" not in payload["azure_endpoint"]
@pytest.mark.parametrize(
    "model_id,family",
    [
        # Known vendor prefixes map to their family name.
        ("anthropic.claude-3-5-sonnet-20241022-v2:0", "anthropic"),
        ("meta.llama3-1-70b-instruct-v1:0", "meta"),
        ("cohere.command-r-v1:0", "cohere"),
        ("amazon.titan-text-express-v1", "amazon"),
        ("ai21.jamba-instruct-v1:0", "ai21"),
        ("mistral.mistral-7b-instruct-v0:2", "mistral"),
        # Unrecognized or empty ids fall back to "unknown".
        ("unknown.model-v1", "unknown"),
        ("", "unknown"),
    ],
)
def test_detect_provider_family(model_id: str, family: str) -> None:
    """Each Bedrock model id resolves to the expected provider family."""
    detected = _detect_provider_family(model_id)
    assert detected == family
def test_rereadable_body_can_be_read_twice() -> None:
    """``_RereadableBody.read`` returns the full payload on repeated calls."""
    from layerlens.instrument.adapters.providers.bedrock_adapter import _RereadableBody

    wrapped = _RereadableBody(b'{"hello":"world"}')

    first_read = wrapped.read()
    assert first_read == b'{"hello":"world"}'
    # A second read, as caller code would do after the adapter consumed the
    # body, must still return the same bytes.
    second_read = wrapped.read()
    assert second_read == b'{"hello":"world"}'
invoke["payload"]["finish_reason"] == "end_turn" + assert invoke["payload"]["response_id"] == "req-abc" + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + # claude-3-5-sonnet pricing in BEDROCK_PRICING: 0.003 input, 0.015 output per 1k. + expected = 10 * 0.003 / 1000 + 5 * 0.015 / 1000 + assert abs(cost["payload"]["api_cost_usd"] - expected) < 1e-6 diff --git a/tests/instrument/adapters/providers/test_cohere_adapter.py b/tests/instrument/adapters/providers/test_cohere_adapter.py new file mode 100644 index 0000000..c101a5a --- /dev/null +++ b/tests/instrument/adapters/providers/test_cohere_adapter.py @@ -0,0 +1,241 @@ +"""Unit tests for the Cohere provider adapter.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.providers.cohere_adapter import ( + ADAPTER_CLASS, + CohereAdapter, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +def _make_v1_response( + text: str = "hello", + input_tokens: int = 10, + output_tokens: int = 5, + response_id: str = "gen-abc", + finish_reason: str = "COMPLETE", + tool_calls: List[Any] = None, +) -> Any: + """Build a v1 Cohere chat response.""" + billed = SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens) + meta = SimpleNamespace(billed_units=billed) + return SimpleNamespace( + text=text, + generation_id=response_id, + meta=meta, + finish_reason=finish_reason, + tool_calls=tool_calls, + ) + + +def _make_v2_response( + text: str = "hello", + input_tokens: int = 10, + output_tokens: int = 5, +) -> Any: + """Build a v2 Cohere chat response.""" + text_block = 
SimpleNamespace(type="text", text=text) + message = SimpleNamespace(content=[text_block], tool_calls=None) + billed = SimpleNamespace(input_tokens=input_tokens, output_tokens=output_tokens) + meta = SimpleNamespace(billed_units=billed) + return SimpleNamespace( + id="msg-xyz", + message=message, + meta=meta, + finish_reason="COMPLETE", + ) + + +def _make_client(*, returns_v1: Any = None, returns_v2: Any = None) -> Any: + def chat(**kwargs: Any) -> Any: + return returns_v1 + + def v2_chat(**kwargs: Any) -> Any: + return returns_v2 + + def embed(**kwargs: Any) -> Any: + return SimpleNamespace( + embeddings=[[0.1, 0.2]], + meta=SimpleNamespace(billed_units=SimpleNamespace(input_tokens=4)), + ) + + v2 = SimpleNamespace(chat=v2_chat) + return SimpleNamespace(chat=chat, v2=v2, embed=embed) + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is CohereAdapter + + +def test_lifecycle() -> None: + a = CohereAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + a.disconnect() + assert a.status == AdapterStatus.DISCONNECTED + + +def test_v1_chat_emits_invoke_and_cost() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(returns_v1=_make_v1_response()) + adapter.connect_client(client) + + client.chat(model="command-r-plus", message="hi") + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "cost.record" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["provider"] == "cohere" + assert invoke["payload"]["model"] == "command-r-plus" + assert invoke["payload"]["prompt_tokens"] == 10 + assert invoke["payload"]["completion_tokens"] == 5 + assert invoke["payload"]["parameters"]["api_version"] == "v1" + assert invoke["payload"]["finish_reason"] == "COMPLETE" + + +def test_v1_chat_input_message_captured() -> None: + stratix = 
_RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client(returns_v1=_make_v1_response()) + adapter.connect_client(client) + + client.chat( + model="command-r", + message="hello world", + preamble="You are a helpful assistant.", + ) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + msgs = invoke["payload"].get("messages") + assert msgs is not None + # Preamble inserted as system message at position 0. + assert msgs[0]["role"] == "system" + assert "helpful" in msgs[0]["content"] + assert msgs[1]["role"] == "user" + assert "hello" in msgs[1]["content"] + + +def test_v1_output_text_captured() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client(returns_v1=_make_v1_response(text="answer-text")) + adapter.connect_client(client) + client.chat(model="command-r", message="x") + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + out = invoke["payload"].get("output_message") + assert out is not None + assert out["content"] == "answer-text" + + +def test_v1_tool_calls_emitted() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + tc = SimpleNamespace(name="lookup", parameters={"q": "weather"}) + client = _make_client(returns_v1=_make_v1_response(tool_calls=[tc])) + adapter.connect_client(client) + client.chat(model="command-r", message="x") + + tool_events = [e for e in stratix.events if e["event_type"] == "tool.call"] + assert len(tool_events) == 1 + assert tool_events[0]["payload"]["tool_name"] == "lookup" + + +def test_v2_chat_emits_invoke() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = 
_make_client(returns_v2=_make_v2_response()) + adapter.connect_client(client) + + client.v2.chat(model="command-r-plus", messages=[{"role": "user", "content": "hi"}]) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["parameters"]["api_version"] == "v2" + assert invoke["payload"]["output_message"]["content"] == "hello" + + +def test_provider_error_emits_policy_violation() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + def bad_chat(**kwargs: Any) -> Any: + raise RuntimeError("rate limited") + + client = SimpleNamespace(chat=bad_chat, v2=None, embed=None) + adapter.connect_client(client) + + with pytest.raises(RuntimeError): + client.chat(model="command-r", message="x") + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "policy.violation" in types + + +def test_disconnect_restores_originals() -> None: + adapter = CohereAdapter() + adapter.connect() + + client = _make_client(returns_v1=_make_v1_response()) + original_chat = client.chat + adapter.connect_client(client) + assert client.chat is not original_chat + adapter.disconnect() + assert client.chat is original_chat + + +def test_known_model_priced() -> None: + """``command-r-plus`` is in the canonical PRICING table.""" + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client( + returns_v1=_make_v1_response(input_tokens=1000, output_tokens=500), + ) + adapter.connect_client(client) + client.chat(model="command-r-plus", message="x") + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + # command-r-plus: 0.003 input + 0.015 output per 1k. 
+ expected = 1000 * 0.003 / 1000 + 500 * 0.015 / 1000 + assert cost["payload"]["api_cost_usd"] == pytest.approx(expected, rel=1e-4) + + +def test_embed_emits_events() -> None: + stratix = _RecordingStratix() + adapter = CohereAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client() + adapter.connect_client(client) + + client.embed(model="embed-english-v3.0", texts=["hi"]) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["request_type"] == "embedding" diff --git a/tests/instrument/adapters/providers/test_litellm_adapter.py b/tests/instrument/adapters/providers/test_litellm_adapter.py new file mode 100644 index 0000000..fb8b2b4 --- /dev/null +++ b/tests/instrument/adapters/providers/test_litellm_adapter.py @@ -0,0 +1,188 @@ +"""Unit tests for the LiteLLM provider adapter.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List +from datetime import datetime, timezone + +import pytest + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.providers.litellm_adapter import ( + ADAPTER_CLASS, + LiteLLMAdapter, + LayerLensLiteLLMCallback, + detect_provider, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +# --------------------------------------------------------------------------- +# detect_provider +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model,expected", + [ + ("openai/gpt-4o", "openai"), + ("anthropic/claude-sonnet", "anthropic"), + ("azure/my-deployment", 
"azure_openai"), + ("bedrock/anthropic.claude-3-5-sonnet", "aws_bedrock"), + ("vertex_ai/gemini-1.5-pro", "google_vertex"), + ("ollama/llama3", "ollama"), + ("cohere/command-r", "cohere"), + ("groq/llama3-70b", "groq"), + ("gpt-4o", "openai"), + ("o1-mini", "openai"), + ("claude-3-5-sonnet", "anthropic"), + ("gemini-2.0-flash", "google_vertex"), + ("llama-3.1-70b", "meta"), + ("mistral-large", "mistral"), + ("totally-unknown-model", "unknown"), + ("", "unknown"), + ], +) +def test_detect_provider_table(model: str, expected: str) -> None: + assert detect_provider(model) == expected + + +# --------------------------------------------------------------------------- +# Adapter lifecycle +# --------------------------------------------------------------------------- + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is LiteLLMAdapter + + +def test_backward_compat_alias() -> None: + """STRATIX* alias preserved for users coming from ateam.""" + from layerlens.instrument.adapters.providers.litellm_adapter import ( + STRATIXLiteLLMCallback, + ) + + assert STRATIXLiteLLMCallback is LayerLensLiteLLMCallback + + +def test_connect_registers_callback_with_litellm() -> None: + import litellm # type: ignore[import-not-found,unused-ignore] + + adapter = LiteLLMAdapter() + try: + adapter.connect() + assert adapter.status in (AdapterStatus.HEALTHY, AdapterStatus.DEGRADED) + if adapter.status == AdapterStatus.HEALTHY: + assert adapter._callback in litellm.callbacks + finally: + adapter.disconnect() + + +def test_disconnect_removes_callback() -> None: + import litellm # type: ignore[import-not-found,unused-ignore] + + adapter = LiteLLMAdapter() + adapter.connect() + cb = adapter._callback + if cb is not None and adapter.status == AdapterStatus.HEALTHY: + assert cb in litellm.callbacks + adapter.disconnect() + if cb is not None: + assert cb not in litellm.callbacks + assert adapter.status == AdapterStatus.DISCONNECTED + + +# 
--------------------------------------------------------------------------- +# Callback handlers +# --------------------------------------------------------------------------- + + +def _make_response_obj() -> Any: + message = SimpleNamespace(role="assistant", content="hello", tool_calls=None) + choice = SimpleNamespace(message=message, finish_reason="stop", index=0) + usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return SimpleNamespace( + id="chatcmpl-x", + model="gpt-4o", + choices=[choice], + usage=usage, + ) + + +def test_log_success_event_emits_invoke_and_cost() -> None: + stratix = _RecordingStratix() + adapter = LiteLLMAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + cb = LayerLensLiteLLMCallback(adapter) + + start = datetime(2026, 1, 1, tzinfo=timezone.utc) + end = datetime(2026, 1, 1, 0, 0, 1, tzinfo=timezone.utc) + + cb.log_success_event( + kwargs={ + "model": "openai/gpt-4o", + "messages": [{"role": "user", "content": "hi"}], + "temperature": 0.5, + }, + response_obj=_make_response_obj(), + start_time=start, + end_time=end, + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "cost.record" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["provider"] == "openai" + # Latency ~ 1s = 1000 ms. 
+ assert 900 < invoke["payload"]["latency_ms"] < 1100 + + +def test_log_failure_event_emits_policy_violation() -> None: + stratix = _RecordingStratix() + adapter = LiteLLMAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + cb = LayerLensLiteLLMCallback(adapter) + + cb.log_failure_event( + kwargs={ + "model": "anthropic/claude-sonnet", + "messages": [{"role": "user", "content": "x"}], + "exception": "rate limited", + }, + response_obj=None, + start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2026, 1, 1, 0, 0, 1, tzinfo=timezone.utc), + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "policy.violation" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["error"] == "rate limited" + assert invoke["payload"]["provider"] == "anthropic" + + +def test_log_stream_event_marks_streaming() -> None: + stratix = _RecordingStratix() + adapter = LiteLLMAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + cb = LayerLensLiteLLMCallback(adapter) + + cb.log_stream_event( + kwargs={"model": "openai/gpt-4o-mini"}, + response_obj=_make_response_obj(), + start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + end_time=datetime(2026, 1, 1, 0, 0, 1, tzinfo=timezone.utc), + ) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["streaming"] is True diff --git a/tests/instrument/adapters/providers/test_mistral_adapter.py b/tests/instrument/adapters/providers/test_mistral_adapter.py new file mode 100644 index 0000000..cbbcf30 --- /dev/null +++ b/tests/instrument/adapters/providers/test_mistral_adapter.py @@ -0,0 +1,267 @@ +"""Unit tests for the Mistral provider adapter.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig 
+from layerlens.instrument.adapters.providers.mistral_adapter import ( + ADAPTER_CLASS, + MistralAdapter, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +def _make_response( + content: str = "hello", + prompt_tokens: int = 10, + completion_tokens: int = 5, + response_id: str = "msg-abc", + finish_reason: str = "stop", + tool_calls: List[Any] = None, +) -> Any: + """Build an OpenAI-shape Mistral response.""" + message = SimpleNamespace(role="assistant", content=content, tool_calls=tool_calls) + choice = SimpleNamespace(message=message, finish_reason=finish_reason, index=0) + usage = SimpleNamespace( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + return SimpleNamespace( + id=response_id, + model="mistral-small-latest", + choices=[choice], + usage=usage, + ) + + +def _make_client(*, returns: Any = None, raises: Exception = None) -> Any: + def complete(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return returns + + def stream(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return iter([]) + + def embed_create(**kwargs: Any) -> Any: + return SimpleNamespace( + data=[SimpleNamespace(embedding=[0.1, 0.2])], + usage=SimpleNamespace( + prompt_tokens=4, completion_tokens=0, total_tokens=4 + ), + ) + + chat = SimpleNamespace(complete=complete, stream=stream) + embeddings = SimpleNamespace(create=embed_create) + return SimpleNamespace(chat=chat, embeddings=embeddings) + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is MistralAdapter + + +def test_lifecycle() -> None: + a = MistralAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + a.disconnect() + assert a.status == AdapterStatus.DISCONNECTED + + 
+def test_complete_emits_invoke_and_cost() -> None: + stratix = _RecordingStratix() + adapter = MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(returns=_make_response()) + adapter.connect_client(client) + + client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "hi"}], + temperature=0.5, + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "cost.record" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["provider"] == "mistral" + assert invoke["payload"]["model"] == "mistral-small-latest" + assert invoke["payload"]["prompt_tokens"] == 10 + assert invoke["payload"]["completion_tokens"] == 5 + assert invoke["payload"]["parameters"]["temperature"] == 0.5 + assert invoke["payload"]["finish_reason"] == "stop" + + +def test_known_model_priced() -> None: + stratix = _RecordingStratix() + adapter = MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client( + returns=_make_response(prompt_tokens=1000, completion_tokens=500), + ) + adapter.connect_client(client) + client.chat.complete( + model="mistral-small-latest", + messages=[{"role": "user", "content": "x"}], + ) + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + # mistral-small-latest: 0.0002 input + 0.0006 output per 1k. 
+ expected = 1000 * 0.0002 / 1000 + 500 * 0.0006 / 1000 + assert cost["payload"]["api_cost_usd"] == pytest.approx(expected, rel=1e-4) + + +def test_provider_error_emits_policy_violation() -> None: + stratix = _RecordingStratix() + adapter = MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client(raises=RuntimeError("rate limited")) + adapter.connect_client(client) + + with pytest.raises(RuntimeError): + client.chat.complete( + model="mistral-large-latest", + messages=[{"role": "user", "content": "x"}], + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "policy.violation" in types + + +def test_tool_calls_extracted() -> None: + stratix = _RecordingStratix() + adapter = MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + fn = SimpleNamespace(name="get_time", arguments='{"tz": "UTC"}') + tc = SimpleNamespace(id="call-1", function=fn) + + client = _make_client(returns=_make_response(tool_calls=[tc])) + adapter.connect_client(client) + client.chat.complete( + model="mistral-large-latest", + messages=[{"role": "user", "content": "what time"}], + ) + + tool_events = [e for e in stratix.events if e["event_type"] == "tool.call"] + assert len(tool_events) == 1 + assert tool_events[0]["payload"]["tool_name"] == "get_time" + assert tool_events[0]["payload"]["tool_input"] == {"tz": "UTC"} + + +def test_disconnect_restores_originals() -> None: + adapter = MistralAdapter() + adapter.connect() + client = _make_client(returns=_make_response()) + original_complete = client.chat.complete + adapter.connect_client(client) + assert client.chat.complete is not original_complete + adapter.disconnect() + assert client.chat.complete is original_complete + + +def test_streaming_emits_consolidated_event() -> None: + """Iterating the stream emits exactly one consolidated model.invoke.""" + stratix = _RecordingStratix() + adapter = 
MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + # Build a synthetic stream of CompletionEvent-like objects. + def stream_events(**kwargs: Any) -> Any: + return iter( + [ + SimpleNamespace( + data=SimpleNamespace( + id="msg-1", + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content="Hello "), + finish_reason=None, + ) + ], + usage=None, + ) + ), + SimpleNamespace( + data=SimpleNamespace( + id="msg-1", + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content="world"), + finish_reason=None, + ) + ], + usage=None, + ) + ), + SimpleNamespace( + data=SimpleNamespace( + id="msg-1", + choices=[ + SimpleNamespace( + delta=SimpleNamespace(content=None), + finish_reason="stop", + ) + ], + usage=SimpleNamespace( + prompt_tokens=8, + completion_tokens=2, + total_tokens=10, + ), + ) + ), + ] + ) + + client = SimpleNamespace( + chat=SimpleNamespace(complete=lambda **kw: None, stream=stream_events), + embeddings=None, + ) + adapter.connect_client(client) + + stream = client.chat.stream( + model="mistral-small-latest", + messages=[{"role": "user", "content": "hi"}], + ) + chunks = list(stream) + assert len(chunks) == 3 + + invokes = [e for e in stratix.events if e["event_type"] == "model.invoke"] + assert len(invokes) == 1 + assert invokes[0]["payload"]["streaming"] is True + assert invokes[0]["payload"]["finish_reason"] == "stop" + assert invokes[0]["payload"]["output_message"]["content"] == "Hello world" + + +def test_embed_emits_events() -> None: + stratix = _RecordingStratix() + adapter = MistralAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client() + adapter.connect_client(client) + + client.embeddings.create(model="mistral-embed", inputs=["hi"]) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["request_type"] == "embedding" 
diff --git a/tests/instrument/adapters/providers/test_ollama_adapter.py b/tests/instrument/adapters/providers/test_ollama_adapter.py new file mode 100644 index 0000000..e314541 --- /dev/null +++ b/tests/instrument/adapters/providers/test_ollama_adapter.py @@ -0,0 +1,121 @@ +"""Unit tests for the Ollama provider adapter.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List + +from layerlens.instrument.adapters._base import AdapterStatus, CaptureConfig +from layerlens.instrument.adapters.providers.ollama_adapter import ( + ADAPTER_CLASS, + OllamaAdapter, +) + + +class _RecordingStratix: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + if len(args) == 2 and isinstance(args[0], str): + self.events.append({"event_type": args[0], "payload": args[1]}) + + +def _make_chat_response() -> Dict[str, Any]: + return { + "message": {"role": "assistant", "content": "hello"}, + "prompt_eval_count": 10, + "eval_count": 5, + "prompt_eval_duration": 1_000_000_000, # 1s + "eval_duration": 2_000_000_000, # 2s + "done_reason": "stop", + } + + +def _make_client(*, chat_response: Any = None) -> Any: + client = SimpleNamespace() + client.chat = lambda **kw: chat_response or _make_chat_response() + client.generate = lambda **kw: { + "response": "hi", + "prompt_eval_count": 5, + "eval_count": 2, + } + client.embeddings = lambda **kw: {"embedding": [0.1, 0.2]} + return client + + +def test_adapter_class_export() -> None: + assert ADAPTER_CLASS is OllamaAdapter + + +def test_connect_uses_env_endpoint(monkeypatch: Any) -> None: + monkeypatch.setenv("OLLAMA_HOST", "http://my-ollama:11434") + a = OllamaAdapter() + a.connect() + assert a.status == AdapterStatus.HEALTHY + assert a._endpoint == "http://my-ollama:11434" + + +def test_connect_default_endpoint(monkeypatch: Any) -> None: + monkeypatch.delenv("OLLAMA_HOST", raising=False) + a = OllamaAdapter() + 
a.connect() + assert a._endpoint == "http://localhost:11434" + + +def test_chat_emits_zero_api_cost() -> None: + """Local inference => api_cost_usd is exactly 0.0.""" + stratix = _RecordingStratix() + adapter = OllamaAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client() + adapter.connect_client(client) + + client.chat(model="llama3.1", messages=[{"role": "user", "content": "hi"}]) + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + assert cost["payload"]["api_cost_usd"] == 0.0 + + +def test_infra_cost_calculated_when_configured() -> None: + stratix = _RecordingStratix() + adapter = OllamaAdapter( + stratix=stratix, + capture_config=CaptureConfig.full(), + cost_per_second=0.01, + ) + adapter.connect() + client = _make_client() + adapter.connect_client(client) + + client.chat(model="llama3.1", messages=[{"role": "user", "content": "hi"}]) + + cost = next(e for e in stratix.events if e["event_type"] == "cost.record") + # 1s + 2s = 3s @ $0.01/s = $0.03 + assert cost["payload"]["infra_cost_usd"] == 0.03 + + +def test_generate_method_works() -> None: + stratix = _RecordingStratix() + adapter = OllamaAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + client = _make_client() + adapter.connect_client(client) + + client.generate(model="llama3.1", prompt="Why is the sky blue?") + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["method"] == "generate" + # Generate response captured as output_message. 
+ assert invoke["payload"].get("output_message") is not None + + +def test_disconnect_restores_originals() -> None: + adapter = OllamaAdapter() + adapter.connect() + client = _make_client() + original_chat = client.chat + adapter.connect_client(client) + assert client.chat is not original_chat + adapter.disconnect() + assert client.chat is original_chat diff --git a/tests/instrument/adapters/providers/test_openai_adapter.py b/tests/instrument/adapters/providers/test_openai_adapter.py new file mode 100644 index 0000000..c29def9 --- /dev/null +++ b/tests/instrument/adapters/providers/test_openai_adapter.py @@ -0,0 +1,537 @@ +"""Unit tests for the OpenAI provider adapter. + +Tests are mocked — no real OpenAI API is contacted. Verifies that the +adapter wraps ``client.chat.completions.create`` and +``client.embeddings.create`` correctly, emits the expected events, and +restores the original methods on disconnect. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, List +from unittest import mock + +import pytest + +from layerlens.instrument.adapters._base import ( + EventSink, + AdapterStatus, + CaptureConfig, + AdapterCapability, +) +from layerlens.instrument.adapters.providers._base.tokens import NormalizedTokenUsage +from layerlens.instrument.adapters.providers.openai_adapter import ( + ADAPTER_CLASS, + OpenAIAdapter, +) + + +class _RecordingStratix: + """Captures every event the adapter emits for assertion.""" + + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + def emit(self, *args: Any, **kwargs: Any) -> None: + # The adapter calls emit_dict_event which calls + # _stratix.emit(event_type, payload). Capture both forms. 
+ if len(args) == 2 and isinstance(args[0], str): + event_type, payload = args + self.events.append({"event_type": event_type, "payload": payload}) + elif len(args) == 1: + self.events.append({"event_type": None, "payload": args[0]}) + + +def _make_response( + *, + content: str = "hello", + prompt_tokens: int = 10, + completion_tokens: int = 5, + total_tokens: int = 15, + finish_reason: str = "stop", + response_id: str = "chatcmpl-abc", + response_model: str = "gpt-4o", + tool_calls: List[Any] = None, +) -> Any: + """Build an object that quacks like an OpenAI ChatCompletion.""" + message = SimpleNamespace( + role="assistant", + content=content, + tool_calls=tool_calls or None, + ) + choice = SimpleNamespace( + message=message, + finish_reason=finish_reason, + index=0, + ) + usage = SimpleNamespace( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + prompt_tokens_details=None, + completion_tokens_details=None, + ) + return SimpleNamespace( + id=response_id, + model=response_model, + choices=[choice], + usage=usage, + system_fingerprint="fp-xyz", + service_tier="default", + ) + + +def _make_client(*, returns: Any = None, raises: Exception = None) -> Any: + """Build an object that quacks like an OpenAI Client.""" + + def _create(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return returns + + def _embed(**kwargs: Any) -> Any: + if raises is not None: + raise raises + return SimpleNamespace( + data=[SimpleNamespace(embedding=[0.1, 0.2, 0.3])], + model=kwargs.get("model"), + usage=SimpleNamespace( + prompt_tokens=8, + completion_tokens=0, + total_tokens=8, + prompt_tokens_details=None, + completion_tokens_details=None, + ), + ) + + completions = mock.MagicMock() + completions.create = _create + chat = SimpleNamespace(completions=completions) + + embeddings = mock.MagicMock() + embeddings.create = _embed + + return SimpleNamespace(chat=chat, embeddings=embeddings) + + +# 
--------------------------------------------------------------------------- +# Lifecycle + metadata +# --------------------------------------------------------------------------- + + +class TestOpenAIAdapterLifecycle: + def test_adapter_class_export(self) -> None: + """Registry uses the ``ADAPTER_CLASS`` convention.""" + assert ADAPTER_CLASS is OpenAIAdapter + + def test_framework_and_version(self) -> None: + adapter = OpenAIAdapter() + assert adapter.FRAMEWORK == "openai" + assert adapter.VERSION == "0.1.0" + + def test_connect_and_disconnect(self) -> None: + adapter = OpenAIAdapter() + adapter.connect() + assert adapter.is_connected is True + assert adapter.status == AdapterStatus.HEALTHY + + adapter.disconnect() + assert adapter.is_connected is False + assert adapter.status == AdapterStatus.DISCONNECTED + + def test_get_adapter_info(self) -> None: + adapter = OpenAIAdapter() + info = adapter.get_adapter_info() + assert info.framework == "openai" + assert info.name == "OpenAIAdapter" + assert AdapterCapability.TRACE_MODELS in info.capabilities + assert AdapterCapability.TRACE_TOOLS in info.capabilities + + def test_health_check(self) -> None: + adapter = OpenAIAdapter() + adapter.connect() + h = adapter.health_check() + assert h.framework_name == "openai" + assert h.status == AdapterStatus.HEALTHY + assert h.error_count == 0 + assert h.circuit_open is False + + def test_serialize_for_replay(self) -> None: + adapter = OpenAIAdapter( + stratix=_RecordingStratix(), + capture_config=CaptureConfig.full(), + ) + adapter.connect() + rt = adapter.serialize_for_replay() + assert rt.framework == "openai" + assert "capture_config" in rt.config + + +# --------------------------------------------------------------------------- +# Wrapping chat.completions.create +# --------------------------------------------------------------------------- + + +class TestOpenAIChatWrap: + def test_connect_client_replaces_create(self) -> None: + adapter = OpenAIAdapter() + client = 
_make_client(returns=_make_response()) + original = client.chat.completions.create + + adapter.connect_client(client) + + assert client.chat.completions.create is not original + assert "chat.completions.create" in adapter._originals + + def test_successful_call_emits_model_invoke_and_cost(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(returns=_make_response()) + adapter.connect_client(client) + + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "hello"}], + temperature=0.7, + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "cost.record" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["model"] == "gpt-4o" + assert invoke["payload"]["provider"] == "openai" + assert invoke["payload"]["prompt_tokens"] == 10 + assert invoke["payload"]["completion_tokens"] == 5 + assert invoke["payload"]["total_tokens"] == 15 + assert invoke["payload"]["latency_ms"] >= 0 + assert invoke["payload"]["parameters"]["temperature"] == 0.7 + # Output message captured because capture_content=True (full preset). 
+ assert "output_message" in invoke["payload"] + + def test_messages_normalized_into_invoke(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(returns=_make_response()) + adapter.connect_client(client) + + client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ], + ) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + msgs = invoke["payload"].get("messages") + assert msgs is not None + assert len(msgs) == 2 + assert msgs[0]["role"] == "system" + assert msgs[1]["role"] == "user" + + def test_capture_content_false_omits_messages(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter( + stratix=stratix, + capture_config=CaptureConfig(capture_content=False), + ) + adapter.connect() + + client = _make_client(returns=_make_response()) + adapter.connect_client(client) + + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "secret"}], + ) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert "messages" not in invoke["payload"] + assert "output_message" not in invoke["payload"] + + def test_response_metadata_captured(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(returns=_make_response(response_id="resp-42")) + adapter.connect_client(client) + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "x"}], + ) + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["response_id"] == "resp-42" + assert invoke["payload"]["finish_reason"] == "stop" + assert invoke["payload"]["system_fingerprint"] == "fp-xyz" + + def 
test_tool_calls_emitted(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + function = SimpleNamespace( + name="get_weather", + arguments='{"city": "SF"}', + ) + tool_call = SimpleNamespace(id="call-1", function=function) + + client = _make_client(returns=_make_response(tool_calls=[tool_call])) + adapter.connect_client(client) + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "weather"}], + ) + + tool_events = [e for e in stratix.events if e["event_type"] == "tool.call"] + assert len(tool_events) == 1 + assert tool_events[0]["payload"]["tool_name"] == "get_weather" + assert tool_events[0]["payload"]["tool_input"] == {"city": "SF"} + assert tool_events[0]["payload"]["tool_call_id"] == "call-1" + + def test_provider_error_emits_policy_violation(self) -> None: + stratix = _RecordingStratix() + adapter = OpenAIAdapter(stratix=stratix, capture_config=CaptureConfig.full()) + adapter.connect() + + client = _make_client(raises=RuntimeError("rate limited")) + adapter.connect_client(client) + + with pytest.raises(RuntimeError): + client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "x"}], + ) + + types = [e["event_type"] for e in stratix.events] + assert "model.invoke" in types + assert "policy.violation" in types + + invoke = next(e for e in stratix.events if e["event_type"] == "model.invoke") + assert invoke["payload"]["error"] == "rate limited" + + def test_disconnect_restores_original(self) -> None: + adapter = OpenAIAdapter() + adapter.connect() + + client = _make_client(returns=_make_response()) + original = client.chat.completions.create + adapter.connect_client(client) + assert client.chat.completions.create is not original + + adapter.disconnect() + assert client.chat.completions.create is original + + +# --------------------------------------------------------------------------- +# Wrapping 
class TestOpenAIEmbeddingsWrap:
    """Covers the adapter's wrapping of ``embeddings.create``."""

    def test_embeddings_emit_invoke_and_cost(self) -> None:
        """An embeddings call emits ``model.invoke`` (tagged ``embedding``) and ``cost.record``."""
        recorder = _RecordingStratix()
        instrumented = OpenAIAdapter(
            stratix=recorder,
            capture_config=CaptureConfig.full(),
        )
        instrumented.connect()

        sdk_client = _make_client()
        instrumented.connect_client(sdk_client)
        sdk_client.embeddings.create(model="text-embedding-3-small", input="hello")

        emitted_types = {evt["event_type"] for evt in recorder.events}
        assert "model.invoke" in emitted_types
        assert "cost.record" in emitted_types

        first_invoke = next(
            evt for evt in recorder.events if evt["event_type"] == "model.invoke"
        )
        assert first_invoke["payload"].get("request_type") == "embedding"
class TestCaptureGating:
    """Capture-config gating of ``model.invoke``."""

    def test_l3_disabled_drops_model_invoke(self) -> None:
        """With ``l3_model_metadata=False`` no ``model.invoke`` is recorded, but ``cost.record`` still is."""
        recorder = _RecordingStratix()
        gated_config = CaptureConfig(l3_model_metadata=False)
        instrumented = OpenAIAdapter(stratix=recorder, capture_config=gated_config)
        instrumented.connect()

        sdk_client = _make_client(returns=_make_response())
        instrumented.connect_client(sdk_client)
        sdk_client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "x"}],
        )

        emitted_types = {evt["event_type"] for evt in recorder.events}
        assert "model.invoke" not in emitted_types
        # cost.record is cross-cutting, so the L3 switch must not suppress it.
        assert "cost.record" in emitted_types
class TestCostCalculation:
    """``cost.record`` pricing for a priced model and for an unrecognized one."""

    def test_known_model_gets_priced_cost(self) -> None:
        """A ``gpt-4o`` call with 1000 prompt / 500 completion tokens is priced at 0.0075 USD."""
        recorder = _RecordingStratix()
        instrumented = OpenAIAdapter(stratix=recorder, capture_config=CaptureConfig.full())
        instrumented.connect()

        canned = _make_response(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
        sdk_client = _make_client(returns=canned)
        instrumented.connect_client(sdk_client)
        sdk_client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": "x"}],
        )

        cost_event = next(evt for evt in recorder.events if evt["event_type"] == "cost.record")
        priced = cost_event["payload"]["api_cost_usd"]
        assert priced is not None
        # gpt-4o = 0.0025 input + 0.01 output per 1k => 0.0025 + 0.005 = 0.0075
        assert priced == pytest.approx(0.0075, rel=1e-4)

    def test_unknown_model_marked_pricing_unavailable(self) -> None:
        """A model name with no pricing entry yields a null cost flagged ``pricing_unavailable``."""
        recorder = _RecordingStratix()
        instrumented = OpenAIAdapter(stratix=recorder, capture_config=CaptureConfig.full())
        instrumented.connect()

        canned = _make_response(response_model="unobtanium-xyz")
        sdk_client = _make_client(returns=canned)
        instrumented.connect_client(sdk_client)
        sdk_client.chat.completions.create(
            model="unobtanium-xyz",
            messages=[{"role": "user", "content": "x"}],
        )

        cost_event = next(evt for evt in recorder.events if evt["event_type"] == "cost.record")
        cost_payload = cost_event["payload"]
        assert cost_payload["api_cost_usd"] is None
        assert cost_payload.get("pricing_unavailable") is True
class TestNormalizedTokenUsage:
    """``with_auto_total`` and ``compute_total`` on ``NormalizedTokenUsage``."""

    def test_with_auto_total_computes_when_zero(self) -> None:
        """Without an explicit total, the total is prompt + completion."""
        usage = NormalizedTokenUsage.with_auto_total(
            prompt_tokens=10,
            completion_tokens=5,
        )
        assert usage.total_tokens == 15

    def test_with_auto_total_respects_explicit_total(self) -> None:
        """A caller-supplied total is kept even when it differs from the sum."""
        token_counts = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 99}
        usage = NormalizedTokenUsage.with_auto_total(**token_counts)
        assert usage.total_tokens == 99

    def test_compute_total_returns_fresh_instance(self) -> None:
        """``compute_total`` returns a new object and leaves the receiver untouched."""
        stale = NormalizedTokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0)
        refreshed = stale.compute_total()
        assert refreshed is not stale
        assert refreshed.total_tokens == 15
        assert stale.total_tokens == 0
**End-to-end transport** — events flow from the real SDK call through + the adapter into the real HTTP transport sink and reach a live + localhost endpoint that mirrors the atlas-app ingest contract. +3. **Streaming behavior** — the streaming wrapper is exercised against + real chunk sequences, not synthesized ones. + +The tests assert on **structural invariants** (event types fired, +required fields present, costs computed) rather than exact byte values, +so they remain stable as model outputs change. +""" + +from __future__ import annotations + +import os +import json +import time +import threading +from typing import Any, Dict, List, Tuple +from http.server import HTTPServer, BaseHTTPRequestHandler + +import pytest + +from layerlens.instrument.adapters._base import CaptureConfig +from layerlens.instrument.transport.sink_http import HttpEventSink +from layerlens.instrument.adapters.providers.openai_adapter import OpenAIAdapter + +pytestmark = [ + pytest.mark.live, + pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY not set; skipping live OpenAI tests", + ), +] + + +@pytest.fixture +def live_openai_client() -> Any: + """Build a real ``openai.OpenAI`` client. + + Skips the test cleanly if the openai package isn't installed. 
def _make_ingest_handler(recorder: _IngestRecorder) -> type:
    """Build a request-handler class that records every POST body on *recorder*.

    Args:
        recorder: Object exposing ``batches`` (a list) and ``lock`` (a context
            manager). Each request appends exactly one dict to ``batches``
            while holding ``lock``.

    Returns:
        A ``BaseHTTPRequestHandler`` subclass suitable for ``HTTPServer``.

    The handler always answers ``200`` with ``{"ok":true}``. A body that is
    not a JSON object (invalid JSON, undecodable bytes, or a JSON
    array/scalar) is stored as ``{"_raw": <text>}`` so consumers can rely on
    every batch being a dict.
    """

    class _Handler(BaseHTTPRequestHandler):
        def log_message(self, *args: Any, **kwargs: Any) -> None:  # noqa: ARG002
            # Suppress the default per-request logging.
            pass

        def do_POST(self) -> None:  # noqa: N802
            # A malformed Content-Length must not abort the request before a
            # response is written; treat it as an empty body instead.
            try:
                length = int(self.headers.get("Content-Length", "0"))
            except (TypeError, ValueError):
                length = 0
            raw = self.rfile.read(length) if length > 0 else b""

            # json.loads raises JSONDecodeError for bad JSON and
            # UnicodeDecodeError for undecodable bytes; both are ValueError.
            try:
                body = json.loads(raw)
            except ValueError:
                body = None
            if not isinstance(body, dict):
                body = {"_raw": raw.decode("utf-8", "replace")}

            with recorder.lock:
                recorder.batches.append(body)
            self.send_response(200)
            self.send_header("Content-Type", "application/json")
            self.end_headers()
            self.wfile.write(b'{"ok":true}')

    return _Handler
_IngestRecorder], + ) -> None: + """A single real ``chat.completions.create`` call must: + + * Reach OpenAI and return a valid response. + * Emit ``model.invoke`` and ``cost.record`` events. + * Route those events through HttpEventSink to the local server. + * Carry usage tokens that match the real response. + """ + base_url, recorder = ingest_server + + sink = HttpEventSink( + adapter_name="openai", + api_key="test-org-key", + base_url=base_url, + path="/telemetry/spans", + max_batch=1, + flush_interval_s=0.0, + ) + + adapter = OpenAIAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + adapter.connect_client(live_openai_client) + + try: + response = live_openai_client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Say hi in one word."}], + max_tokens=5, + ) + finally: + sink.close() + adapter.disconnect() + + # Real OpenAI response shape. + assert response.choices, "OpenAI returned no choices" + assert response.usage is not None, "OpenAI returned no usage" + + # Adapter routed events through the sink to our localhost server. + time.sleep(0.5) # give close() a moment if needed + with recorder.lock: + batches = list(recorder.batches) + assert batches, "no events reached the ingest server" + + all_events: List[Dict[str, Any]] = [] + for batch in batches: + all_events.extend(batch.get("events", [])) + + types = [e["event_type"] for e in all_events] + assert "model.invoke" in types, f"missing model.invoke in {types}" + assert "cost.record" in types, f"missing cost.record in {types}" + + invoke = next(e for e in all_events if e["event_type"] == "model.invoke") + # Real provider field — would FAIL if SDK renamed `usage.prompt_tokens`. 
+ assert invoke["payload"]["prompt_tokens"] == response.usage.prompt_tokens + assert invoke["payload"]["completion_tokens"] == response.usage.completion_tokens + assert invoke["payload"]["total_tokens"] == response.usage.total_tokens + assert invoke["payload"]["latency_ms"] > 0 + assert invoke["payload"]["model"] == "gpt-4o-mini" + + cost = next(e for e in all_events if e["event_type"] == "cost.record") + # gpt-4o-mini IS in the pricing table — must compute a cost. + assert cost["payload"]["api_cost_usd"] is not None + assert cost["payload"]["api_cost_usd"] >= 0 + + def test_real_streaming_emits_consolidated_event( + self, + live_openai_client: Any, + ingest_server: Tuple[str, _IngestRecorder], + ) -> None: + """Streaming consumption must emit exactly one ``model.invoke`` + on stream completion (not one per chunk).""" + base_url, recorder = ingest_server + + sink = HttpEventSink( + adapter_name="openai", + api_key="test-org-key", + base_url=base_url, + max_batch=1, + flush_interval_s=0.0, + ) + + adapter = OpenAIAdapter(capture_config=CaptureConfig.standard()) + adapter.add_sink(sink) + adapter.connect() + adapter.connect_client(live_openai_client) + + try: + stream = live_openai_client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": "Count to three."}], + max_tokens=20, + stream=True, + stream_options={"include_usage": True}, + ) + chunks_seen = 0 + for _chunk in stream: + chunks_seen += 1 + assert chunks_seen > 0, "stream produced no chunks" + finally: + sink.close() + adapter.disconnect() + + time.sleep(0.5) + with recorder.lock: + batches = list(recorder.batches) + + all_events: List[Dict[str, Any]] = [] + for batch in batches: + all_events.extend(batch.get("events", [])) + + invoke_events = [e for e in all_events if e["event_type"] == "model.invoke"] + # Exactly one model.invoke per LLM call, regardless of chunk count. + assert len(invoke_events) == 1 + # The streaming flag is captured in metadata. 
def test_real_error_path_emits_policy_violation(
    self,
    live_openai_client: Any,
    ingest_server: Tuple[str, _IngestRecorder],
) -> None:
    """An invalid model name produces a real OpenAI error which the
    adapter must convert into a ``policy.violation`` event.

    The failing call must still emit a ``model.invoke`` event whose payload
    carries an ``error`` field. Events are read back from the local ingest
    server with a bounded poll instead of a fixed sleep, so slow delivery
    does not fail the test spuriously and fast delivery does not waste time.
    """
    base_url, recorder = ingest_server

    sink = HttpEventSink(
        adapter_name="openai",
        api_key="test-org-key",
        base_url=base_url,
        max_batch=1,
        flush_interval_s=0.0,
    )

    adapter = OpenAIAdapter(capture_config=CaptureConfig.standard())
    adapter.add_sink(sink)
    adapter.connect()
    adapter.connect_client(live_openai_client)

    try:
        with pytest.raises(Exception):  # noqa: B017 - OpenAI raises one of several SDK error types
            live_openai_client.chat.completions.create(
                model="this-model-definitely-does-not-exist-xyz123",
                messages=[{"role": "user", "content": "x"}],
                max_tokens=1,
            )
    finally:
        sink.close()
        adapter.disconnect()

    # Poll until both expected event types have arrived or the deadline
    # passes; the assertions below then report whatever is still missing.
    expected_types = {"model.invoke", "policy.violation"}
    deadline = time.monotonic() + 5.0
    while True:
        with recorder.lock:
            batches = list(recorder.batches)
        all_events: List[Dict[str, Any]] = []
        for batch in batches:
            all_events.extend(batch.get("events", []))
        types = [e["event_type"] for e in all_events]
        if expected_types.issubset(types) or time.monotonic() >= deadline:
            break
        time.sleep(0.05)

    assert "model.invoke" in types  # error variant
    assert "policy.violation" in types

    invoke = next(e for e in all_events if e["event_type"] == "model.invoke")
    assert "error" in invoke["payload"]
class _RecordingStratix:
    """In-memory recorder that keeps ``emit(event_type, payload)`` calls.

    Only calls with exactly two positional arguments whose first argument is
    a string are stored; every other call shape is ignored. Keyword
    arguments never affect what is recorded.
    """

    def __init__(self) -> None:
        self.events: List[Dict[str, Any]] = []

    def emit(self, *args: Any, **kwargs: Any) -> None:
        """Append ``{"event_type", "payload"}`` for a two-positional string-typed emit."""
        if len(args) != 2:
            return
        event_type, payload = args
        if not isinstance(event_type, str):
            return
        self.events.append({"event_type": event_type, "payload": payload})
def test_strips_models_prefix() -> None:
    """A client named ``models/gemini-1.5-pro`` is reported as ``gemini-1.5-pro``."""
    recorder = _RecordingStratix()
    vertex_adapter = GoogleVertexAdapter(
        stratix=recorder,
        capture_config=CaptureConfig.full(),
    )
    vertex_adapter.connect()

    model_client = _make_model_client(model_name="models/gemini-1.5-pro")
    vertex_adapter.connect_client(model_client)
    model_client.generate_content("hi")

    first_invoke = next(
        evt for evt in recorder.events if evt["event_type"] == "model.invoke"
    )
    reported_model = first_invoke["payload"]["model"]
    assert reported_model == "gemini-1.5-pro"
def _import_adapter_class(framework: str) -> Type[BaseAdapter]:
    """Import the adapter class for a given framework.

    Looks the framework up in the registry's ``_ADAPTER_MODULES`` and imports
    that module. If the module does not exist, it falls back to
    ``layerlens.instrument.adapters.frameworks.<framework>``.

    Skips the calling test when:

    * importing the adapter raises ``RuntimeError``, which is what the
      Pydantic compat check raises when the active runtime does not match
      the adapter's declaration (that path has its own dedicated tests that
      mock ``PYDANTIC_V2``);
    * the module exposes no ``ADAPTER_CLASS``;
    * ``ADAPTER_CLASS`` is not a :class:`BaseAdapter` subclass.

    Args:
        framework: Registry key of the framework adapter.

    Returns:
        The adapter class exported as ``ADAPTER_CLASS``.

    Raises:
        KeyError: If *framework* is not a key of ``_ADAPTER_MODULES``.
        ModuleNotFoundError: If neither the registry path nor the fallback
            path can be imported.
    """
    import importlib

    from layerlens.instrument.adapters._base.registry import _ADAPTER_MODULES

    module_path = _ADAPTER_MODULES[framework]
    try:
        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError:
            # The registry references a module that doesn't exist on disk
            # (pre-existing for ``langfuse_importer`` and ``browser_use``).
            # Fall back to the package path matching the framework name.
            fallback = f"layerlens.instrument.adapters.frameworks.{framework}"
            module = importlib.import_module(fallback)
    except RuntimeError as exc:
        # Import-time compat rejection under the active Pydantic runtime is
        # an expected condition here, not a failure of the calling test.
        pytest.skip(f"{framework} adapter cannot be imported under the active runtime: {exc}")

    cls = getattr(module, "ADAPTER_CLASS", None)
    if cls is None:
        pytest.skip(f"{framework} has no ADAPTER_CLASS — not a registered adapter")
    if not isinstance(cls, type) or not issubclass(cls, BaseAdapter):
        pytest.skip(f"{framework}.ADAPTER_CLASS is not a BaseAdapter subclass")
    return cls
@pytest.mark.parametrize("framework", _FRAMEWORK_ADAPTERS)
def test_adapter_sets_compat_explicitly(framework: str) -> None:
    """Each framework adapter must declare ``requires_pydantic`` below ``BaseAdapter``.

    Collects the classes that precede :class:`BaseAdapter` in the adapter's
    MRO and requires at least one of them to define ``requires_pydantic`` in
    its own namespace. An attribute that is only inherited from
    :class:`BaseAdapter` does not count.
    """
    adapter_cls = _import_adapter_class(framework)

    # Everything ahead of BaseAdapter in the MRO: the adapter itself plus any
    # intermediate framework-specific bases.
    ahead_of_base = []
    for klass in adapter_cls.__mro__:
        if klass is BaseAdapter:
            break
        ahead_of_base.append(klass)

    declaring_classes: Set[str] = {
        klass.__name__ for klass in ahead_of_base if "requires_pydantic" in vars(klass)
    }

    assert declaring_classes, (
        f"{framework} adapter ({adapter_cls.__name__}) does not set "
        "``requires_pydantic`` on its own class. Add an explicit declaration "
        "(V1_ONLY, V2_ONLY, or V1_OR_V2) — relying on the BaseAdapter "
        "default is forbidden by the Round-2 item 20 lint."
    )
def test_requires_pydantic_v2_only_raises_under_v1() -> None:
    """With ``PYDANTIC_V2`` patched to ``False``, a V2_ONLY declaration raises ``RuntimeError``."""
    flag_path = "layerlens.instrument.adapters._base.pydantic_compat.PYDANTIC_V2"
    with mock.patch(flag_path, False), pytest.raises(RuntimeError) as exc_info:
        requires_pydantic(PydanticCompat.V2_ONLY)

    message = str(exc_info.value)
    # The last fragment is the actionable guidance the message must carry.
    for fragment in ("Pydantic v2", "v2_only", "pip install"):
        assert fragment in message
@pytest.mark.parametrize("framework", _FRAMEWORK_ADAPTERS)
def test_adapter_info_exposes_compat(framework: str) -> None:
    """``info()`` on an instance reports the same ``requires_pydantic`` as the class.

    Adapters whose constructor raises ``TypeError`` when called without
    arguments are skipped.
    """
    adapter_cls = _import_adapter_class(framework)
    try:
        adapter = adapter_cls()  # type: ignore[call-arg]
    except TypeError:
        pytest.skip(f"{framework} adapter cannot be instantiated with no args")

    reported = adapter.info().requires_pydantic
    assert reported == adapter_cls.requires_pydantic
+ """ + cls = _import_adapter_class(framework) + assert cls.requires_pydantic is PydanticCompat.V2_ONLY, ( + f"{framework} is expected V2_ONLY (see adapter docstring for the " + f"specific Pydantic v2 imports / framework pin that justifies it)" + )