From 9c9ff9ed7a7886f85d2ccde935d3835c51d52fc4 Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Thu, 19 Mar 2026 21:43:41 -0700 Subject: [PATCH 1/3] Add layerlens.instrument, tracing core + initial adapters --- examples/instrument_langchain.py | 30 +++ examples/instrument_openai.py | 46 ++++ pyproject.toml | 6 + src/layerlens/instrument/__init__.py | 13 ++ src/layerlens/instrument/_context.py | 11 + src/layerlens/instrument/_decorator.py | 93 ++++++++ src/layerlens/instrument/_recorder.py | 22 ++ src/layerlens/instrument/_span.py | 43 ++++ src/layerlens/instrument/_types.py | 44 ++++ src/layerlens/instrument/_upload.py | 35 +++ src/layerlens/instrument/adapters/__init__.py | 1 + .../adapters/frameworks/__init__.py | 1 + .../adapters/frameworks/_base_framework.py | 69 ++++++ .../adapters/frameworks/langchain.py | 215 +++++++++++++++++ .../adapters/frameworks/langgraph.py | 39 ++++ .../instrument/adapters/providers/__init__.py | 1 + .../adapters/providers/_base_provider.py | 73 ++++++ .../adapters/providers/anthropic.py | 120 ++++++++++ .../instrument/adapters/providers/litellm.py | 83 +++++++ .../instrument/adapters/providers/openai.py | 138 +++++++++++ tests/instrument/__init__.py | 0 tests/instrument/conftest.py | 26 +++ tests/instrument/test_adapters.py | 130 +++++++++++ tests/instrument/test_core.py | 163 +++++++++++++ tests/instrument/test_providers.py | 217 ++++++++++++++++++ tests/instrument/test_types.py | 58 +++++ 26 files changed, 1677 insertions(+) create mode 100644 examples/instrument_langchain.py create mode 100644 examples/instrument_openai.py create mode 100644 src/layerlens/instrument/__init__.py create mode 100644 src/layerlens/instrument/_context.py create mode 100644 src/layerlens/instrument/_decorator.py create mode 100644 src/layerlens/instrument/_recorder.py create mode 100644 src/layerlens/instrument/_span.py create mode 100644 src/layerlens/instrument/_types.py create mode 100644 
src/layerlens/instrument/_upload.py create mode 100644 src/layerlens/instrument/adapters/__init__.py create mode 100644 src/layerlens/instrument/adapters/frameworks/__init__.py create mode 100644 src/layerlens/instrument/adapters/frameworks/_base_framework.py create mode 100644 src/layerlens/instrument/adapters/frameworks/langchain.py create mode 100644 src/layerlens/instrument/adapters/frameworks/langgraph.py create mode 100644 src/layerlens/instrument/adapters/providers/__init__.py create mode 100644 src/layerlens/instrument/adapters/providers/_base_provider.py create mode 100644 src/layerlens/instrument/adapters/providers/anthropic.py create mode 100644 src/layerlens/instrument/adapters/providers/litellm.py create mode 100644 src/layerlens/instrument/adapters/providers/openai.py create mode 100644 tests/instrument/__init__.py create mode 100644 tests/instrument/conftest.py create mode 100644 tests/instrument/test_adapters.py create mode 100644 tests/instrument/test_core.py create mode 100644 tests/instrument/test_providers.py create mode 100644 tests/instrument/test_types.py diff --git a/examples/instrument_langchain.py b/examples/instrument_langchain.py new file mode 100644 index 0000000..e19a515 --- /dev/null +++ b/examples/instrument_langchain.py @@ -0,0 +1,30 @@ +"""Example: Instrument a LangChain chain with automatic span capture. 
+ +Requires: + pip install layerlens[langchain] langchain-openai + export LAYERLENS_STRATIX_API_KEY="your-api-key" + export OPENAI_API_KEY="your-openai-key" +""" + +from langchain_openai import ChatOpenAI +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser + +from layerlens import Stratix +from layerlens.instrument.adapters.frameworks.langchain import LangChainCallbackHandler + +client = Stratix() +handler = LangChainCallbackHandler(client) + +# Build a simple chain +prompt = ChatPromptTemplate.from_template("Answer this question concisely: {question}") +llm = ChatOpenAI(model="gpt-4o") +chain = prompt | llm | StrOutputParser() + +if __name__ == "__main__": + # The callback handler captures the full chain execution as a trace + result = chain.invoke( + {"question": "What is retrieval-augmented generation?"}, + config={"callbacks": [handler]}, + ) + print(f"Answer: {result}") diff --git a/examples/instrument_openai.py b/examples/instrument_openai.py new file mode 100644 index 0000000..92118a1 --- /dev/null +++ b/examples/instrument_openai.py @@ -0,0 +1,46 @@ +"""Example: Instrument OpenAI with automatic LLM span capture. + +Requires: + pip install layerlens[openai] + export LAYERLENS_STRATIX_API_KEY="your-api-key" + export OPENAI_API_KEY="your-openai-key" +""" + +import openai +from layerlens import Stratix +from layerlens.instrument import span, trace +from layerlens.instrument.adapters.providers.openai import instrument_openai + +client = Stratix() +openai_client = openai.OpenAI() + +# Instrument the OpenAI client — all chat.completions.create calls +# inside a @trace will generate LLM spans automatically. 
+instrument_openai(openai_client) + + +@trace(client) +def qa_agent(question: str): + """Simple Q&A agent with a retrieval step and an LLM call.""" + + # Manual span for a retrieval step + with span("retrieve", kind="retriever") as s: + # In a real app, this would query a vector database + docs = ["Python is a programming language.", "It was created by Guido van Rossum."] + s.output = docs + + # The OpenAI call is automatically instrumented — no span() needed + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": f"Answer using this context: {docs}"}, + {"role": "user", "content": question}, + ], + ) + + return response.choices[0].message.content + + +if __name__ == "__main__": + answer = qa_agent("What is Python and who created it?") + print(f"Answer: {answer}") diff --git a/pyproject.toml b/pyproject.toml index fc8baa6..30e15da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,10 @@ classifiers = [ [project.optional-dependencies] cli = ["click>=8.0.0"] +openai = ["openai>=1.0.0"] +anthropic = ["anthropic>=0.18.0"] +langchain = ["langchain-core>=0.1.0"] +litellm = ["litellm>=1.0.0"] [project.urls] Homepage = "https://github.com/LayerLens/stratix-python" @@ -138,6 +142,8 @@ known-first-party = ["openai", "tests"] "tests/**.py" = ["T201", "T203"] "examples/**.py" = ["T201", "T203"] "src/layerlens/cli/**" = ["T201", "T203"] +"src/layerlens/instrument/adapters/frameworks/langchain.py" = ["ARG002"] +"src/layerlens/instrument/adapters/frameworks/langgraph.py" = ["ARG002"] [tool.pyright] include = ["src", "tests"] diff --git a/src/layerlens/instrument/__init__.py b/src/layerlens/instrument/__init__.py new file mode 100644 index 0000000..2e11b51 --- /dev/null +++ b/src/layerlens/instrument/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from ._span import span +from ._types import SpanData +from ._recorder import TraceRecorder +from ._decorator import trace + +__all__ = [ + 
def trace(
    client: Any,
    *,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Callable[..., Any]:
    """Return a decorator that records each call of a function as one trace.

    Every call builds a fresh ``TraceRecorder`` whose root ``SpanData`` has
    kind ``"chain"``, publishes both through the ``_current_recorder`` and
    ``_current_span`` context variables while the function runs, then finishes
    the root span and flushes the recorder. Coroutine functions get an async
    wrapper that awaits ``recorder.async_flush()``; all other callables get a
    sync wrapper that calls ``recorder.flush()``.

    Args:
        client: Object handed to ``TraceRecorder``; it performs the upload.
        name: Root span name. Defaults to ``fn.__name__``.
        metadata: Optional mapping copied into each root span's metadata.

    Returns:
        A decorator that preserves the wrapped function's metadata
        (``functools.wraps``).

    Raises:
        Whatever the wrapped function raises (re-raised after the errored
        trace has been flushed), or whatever the flush itself raises.
    """

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        span_name = name or fn.__name__

        def _begin(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Tuple[Any, Any]:
            # One recorder and root span per call. The metadata mapping is
            # copied so spans from different calls never share (and mutate)
            # the dict object the caller passed to ``trace``.
            recorder = TraceRecorder(client)
            root = SpanData(
                name=span_name,
                kind="chain",
                input=_capture_input(args, kwargs),
                metadata=dict(metadata) if metadata else {},
            )
            recorder.root = root
            return recorder, root

        if asyncio.iscoroutinefunction(fn):

            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                recorder, root = _begin(args, kwargs)
                rec_token = _current_recorder.set(recorder)
                span_token = _current_span.set(root)
                try:
                    # Only the user's call is guarded. A failing flush on the
                    # success path must not be mistaken for a failure of
                    # ``fn`` (which would mark the span as errored and
                    # trigger a second upload).
                    result = await fn(*args, **kwargs)
                except Exception as exc:
                    root.finish(error=str(exc))
                    await recorder.async_flush()
                    raise
                else:
                    root.output = result
                    root.finish()
                    await recorder.async_flush()
                    return result
                finally:
                    # Reset in reverse order of the ``set`` calls.
                    _current_span.reset(span_token)
                    _current_recorder.reset(rec_token)

            return async_wrapper

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            recorder, root = _begin(args, kwargs)
            rec_token = _current_recorder.set(recorder)
            span_token = _current_span.set(root)
            try:
                # See async_wrapper: the flush is deliberately outside the
                # guarded region.
                result = fn(*args, **kwargs)
            except Exception as exc:
                root.finish(error=str(exc))
                recorder.flush()
                raise
            else:
                root.output = result
                root.finish()
                recorder.flush()
                return result
            finally:
                _current_span.reset(span_token)
                _current_recorder.reset(rec_token)

        return sync_wrapper

    return decorator
@contextmanager
def span(
    name: str,
    *,
    kind: str = "internal",
    input: Any = None,
    metadata: Optional[Dict[str, Any]] = None,
) -> Generator["SpanData", None, None]:
    """Open a child span under the currently active span.

    When no recorder or no parent span is active (both are read from context
    variables), a detached ``SpanData`` is yielded so calling code keeps
    working; that span is not linked to any trace and is not finished here.

    Otherwise the new span is appended to the parent's ``children``, made the
    current span for the duration of the ``with`` body, and finished on exit
    (with ``error`` set to ``str(exc)`` if the body raised).

    Args:
        name: Span name.
        kind: Span kind label, ``"internal"`` by default.
        input: Value stored as the span's input.
        metadata: Optional mapping copied into the span's metadata.

    Yields:
        The ``SpanData`` for the block; callers may set ``output`` on it.
    """
    recorder = _current_recorder.get()
    parent = _current_span.get()
    # Copy so the span never aliases (and mutates) the caller's dict.
    meta = dict(metadata) if metadata else {}

    if recorder is None or parent is None:
        yield SpanData(name=name, kind=kind, input=input, metadata=meta)
        return

    child = SpanData(
        name=name,
        kind=kind,
        parent_id=parent.span_id,
        input=input,
        metadata=meta,
    )
    parent.children.append(child)

    token = _current_span.set(child)
    try:
        yield child
    except BaseException as exc:
        # BaseException rather than Exception: asyncio.CancelledError and
        # KeyboardInterrupt would otherwise leave the span without an
        # end_time and with status "ok". The exception is always re-raised.
        child.finish(error=str(exc))
        raise
    else:
        child.finish()
    finally:
        _current_span.reset(token)
def _write_trace_file(trace_data: Dict[str, Any]) -> str:
    """Serialize ``[trace_data]`` to a new temp ``.json`` file; return its path.

    Values ``json`` cannot encode natively are stringified (``default=str``).
    If writing fails, the temp file is removed before the error propagates.
    """
    fd, path = tempfile.mkstemp(suffix=".json", prefix="layerlens_trace_")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump([trace_data], f, default=str)
    except BaseException:
        _remove_trace_file(path)
        raise
    return path


def _remove_trace_file(path: str) -> None:
    """Delete the temp trace file, logging (not raising) if that fails."""
    try:
        os.unlink(path)
    except OSError:
        log.debug("Failed to remove temp trace file: %s", path)


def upload_trace(client: Any, trace_data: Dict[str, Any]) -> None:
    """Write *trace_data* to a temp JSON file and pass its path to the client.

    Args:
        client: Object exposing ``traces.upload(path)``.
        trace_data: Trace dictionary; it is uploaded wrapped in a one-element
            JSON array.

    Raises:
        Whatever ``client.traces.upload`` raises. The temp file is removed
        either way.
    """
    path = _write_trace_file(trace_data)
    try:
        client.traces.upload(path)
    finally:
        _remove_trace_file(path)


async def async_upload_trace(client: Any, trace_data: Dict[str, Any]) -> None:
    """Async counterpart of :func:`upload_trace`.

    Works with both kinds of client: if ``client.traces.upload(path)`` returns
    an awaitable it is awaited; if it returns a plain value (a synchronous
    client used from an ``async def`` function) the upload has already
    happened and nothing is awaited. Previously the result was awaited
    unconditionally, which raised ``TypeError`` for synchronous clients after
    the upload had run.

    Args:
        client: Object exposing ``traces.upload(path)`` (sync or async).
        trace_data: Trace dictionary; it is uploaded wrapped in a one-element
            JSON array.

    Raises:
        Whatever ``client.traces.upload`` raises. The temp file is removed
        either way.
    """
    import inspect  # local import: only this function needs it

    path = _write_trace_file(trace_data)
    try:
        result = client.traces.upload(path)
        if inspect.isawaitable(result):
            await result
    finally:
        _remove_trace_file(path)
class FrameworkTracer:
    """Build a span tree from framework callbacks keyed by run id.

    Spans are stored by ``str(run_id)``. The first run seen while no root is
    recorded becomes the root; when that root run finishes, the whole tree is
    uploaded through ``upload_trace`` and the tracer state is reset.
    """

    def __init__(self, client: Any) -> None:
        self._client = client
        # str(run_id) -> span for every run of the trace being collected.
        self._spans: Dict[str, "SpanData"] = {}
        # str(run_id) of the root run, or None when no trace is in progress.
        self._root_run_id: Optional[str] = None

    def _get_or_create_span(
        self,
        run_id: UUID,
        parent_run_id: Optional[UUID],
        name: str,
        kind: str,
        input: Any = None,
    ) -> "SpanData":
        """Return the span for *run_id*, creating and linking it if needed.

        A new span is attached to its parent's ``children`` only when
        *parent_run_id* refers to a span that is currently tracked; otherwise
        it is created without a parent link.
        """
        rid = str(run_id)
        existing = self._spans.get(rid)
        if existing is not None:
            return existing

        parent_span = None
        if parent_run_id is not None:
            parent_span = self._spans.get(str(parent_run_id))

        created = SpanData(
            name=name,
            kind=kind,
            parent_id=parent_span.span_id if parent_span else None,
            input=input,
        )
        self._spans[rid] = created

        if parent_span is not None:
            parent_span.children.append(created)

        if self._root_run_id is None:
            self._root_run_id = rid

        return created

    def _finish_span(self, run_id: UUID, output: Any = None, error: Optional[str] = None) -> None:
        """Record the result of *run_id*; flush when it is the root run.

        Unknown run ids are ignored.
        """
        rid = str(run_id)
        finished = self._spans.get(rid)
        if finished is None:
            return
        finished.output = output
        finished.finish(error=error)

        if rid == self._root_run_id:
            self._flush()

    def _flush(self) -> None:
        """Upload the collected tree and reset state for the next trace.

        The reset happens even when the upload raises. Otherwise the stale
        root id would stay set, no later run would ever match it, and spans
        from subsequent traces would accumulate without being uploaded.
        """
        if self._root_run_id is None:
            return
        root = self._spans.get(self._root_run_id)
        if root is None:
            return

        try:
            upload_trace(self._client, root.to_dict())
        finally:
            self._spans.clear()
            self._root_run_id = None
b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from uuid import UUID +from typing import Any, Dict, List, Optional, Sequence + +from ._base_framework import FrameworkTracer + +try: + from langchain_core.callbacks import BaseCallbackHandler # pyright: ignore[reportAssignmentType] +except ImportError: + + class BaseCallbackHandler: # type: ignore[no-redef,misc] + def __init_subclass__(cls, **kwargs: Any) -> None: + raise ImportError( + "The 'langchain-core' package is required for LangChain instrumentation. " + "Install it with: pip install layerlens[langchain]" + ) + + +class LangChainCallbackHandler(BaseCallbackHandler, FrameworkTracer): + def __init__(self, client: Any) -> None: + BaseCallbackHandler.__init__(self) + FrameworkTracer.__init__(self, client) + + # -- Chain -- + + def on_chain_start( + self, + serialized: Optional[Dict[str, Any]], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + serialized = serialized or {} + name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] + self._get_or_create_span(run_id, parent_run_id, name=name, kind="chain", input=inputs) + + def on_chain_end( + self, + outputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self._finish_span(run_id, output=outputs) + + def on_chain_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self._finish_span(run_id, error=str(error)) + + # -- LLM -- + + def on_llm_start( + self, + serialized: Optional[Dict[str, Any]], + prompts: List[str], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + serialized = serialized or {} + name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] + self._get_or_create_span(run_id, parent_run_id, name=name, kind="llm", 
input=prompts) + + def on_chat_model_start( + self, + serialized: Optional[Dict[str, Any]], + messages: List[List[Any]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + serialized = serialized or {} + name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] + input_data = [[_serialize_lc_message(m) for m in batch] for batch in messages] + self._get_or_create_span(run_id, parent_run_id, name=name, kind="llm", input=input_data) + + def on_llm_end( + self, + response: Any, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + output = None + try: + generations = response.generations + if generations and generations[0]: + output = generations[0][0].text + except (AttributeError, IndexError): + pass + + s = self._spans.get(str(run_id)) + if s is not None: + try: + llm_output = response.llm_output + if llm_output: + if "token_usage" in llm_output: + s.metadata["usage"] = llm_output["token_usage"] + if "model_name" in llm_output: + s.metadata["model"] = llm_output["model_name"] + except AttributeError: + pass + + self._finish_span(run_id, output=output) + + def on_llm_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self._finish_span(run_id, error=str(error)) + + # -- Tool -- + + def on_tool_start( + self, + serialized: Optional[Dict[str, Any]], + input_str: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + name = (serialized or {}).get("name", "tool") + self._get_or_create_span(run_id, parent_run_id, name=name, kind="tool", input=input_str) + + def on_tool_end( + self, + output: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self._finish_span(run_id, output=output) + + def on_tool_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, 
+ ) -> None: + self._finish_span(run_id, error=str(error)) + + # -- Retriever -- + + def on_retriever_start( + self, + serialized: Optional[Dict[str, Any]], + query: str, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + name = (serialized or {}).get("name", "retriever") + self._get_or_create_span(run_id, parent_run_id, name=name, kind="retriever", input=query) + + def on_retriever_end( + self, + documents: Sequence[Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + output = [_serialize_lc_document(d) for d in documents] + self._finish_span(run_id, output=output) + + def on_retriever_error( + self, + error: BaseException, + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + self._finish_span(run_id, error=str(error)) + + # -- Text (required by base) -- + + def on_text(self, text: str, **kwargs: Any) -> None: + pass + + +def _serialize_lc_message(msg: Any) -> Any: + try: + return {"type": msg.type, "content": msg.content} + except AttributeError: + return str(msg) + + +def _serialize_lc_document(doc: Any) -> Any: + try: + return {"page_content": doc.page_content, "metadata": doc.metadata} + except AttributeError: + return str(doc) diff --git a/src/layerlens/instrument/adapters/frameworks/langgraph.py b/src/layerlens/instrument/adapters/frameworks/langgraph.py new file mode 100644 index 0000000..1d72bab --- /dev/null +++ b/src/layerlens/instrument/adapters/frameworks/langgraph.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from uuid import UUID +from typing import Any, Dict, List, Optional + +from .langchain import LangChainCallbackHandler + + +class LangGraphCallbackHandler(LangChainCallbackHandler): + def on_chain_start( + self, + serialized: Optional[Dict[str, Any]], + inputs: Dict[str, Any], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + 
serialized = serialized or {} + name = serialized.get("name") or serialized.get("id", ["unknown"])[-1] + + # Extract node name from LangGraph tags + if tags: + for tag in tags: + if isinstance(tag, str) and tag.startswith("graph:step:"): + continue + if isinstance(tag, str) and ":" not in tag: + name = tag + break + + # Check kwargs for langgraph-specific metadata + metadata = kwargs.get("metadata", {}) + if isinstance(metadata, dict): + node_name = metadata.get("langgraph_node") + if node_name: + name = node_name + + self._get_or_create_span(run_id, parent_run_id, name=name, kind="chain", input=inputs) diff --git a/src/layerlens/instrument/adapters/providers/__init__.py b/src/layerlens/instrument/adapters/providers/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/layerlens/instrument/adapters/providers/_base_provider.py b/src/layerlens/instrument/adapters/providers/_base_provider.py new file mode 100644 index 0000000..f01a935 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/_base_provider.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple, Callable, Optional + +from ..._types import SpanData +from ..._context import _current_span, _current_recorder + + +def create_llm_span( + name: str, + kwargs: Dict[str, Any], + capture_params: frozenset[str], +) -> Tuple[Optional[SpanData], Any]: + recorder = _current_recorder.get() + parent = _current_span.get() + + if recorder is None or parent is None: + return None, None + + meta = {k: kwargs[k] for k in capture_params if k in kwargs} + + s = SpanData( + name=name, + kind="llm", + parent_id=parent.span_id, + input=_extract_messages(kwargs), + metadata=meta, + ) + parent.children.append(s) + token = _current_span.set(s) + return s, token + + +def finish_llm_span( + span: SpanData, + token: Any, + response: Any, + 
class AnthropicProvider:
    """Patch ``client.messages.create`` so calls inside a trace become LLM spans.

    The original callable is remembered and restored by :meth:`disconnect`.
    """

    def __init__(self) -> None:
        self._client: Any = None
        # Dotted attribute path -> original callable, restored in disconnect().
        self._originals: Dict[str, Any] = {}

    def connect_client(self, client: Any) -> Any:
        """Instrument *client* in place and return it.

        If this provider is already connected, it is disconnected first so
        the real original (not a previously installed wrapper) is what gets
        stored for restoration.
        """
        if self._client is not None:
            self.disconnect()
        self._client = client

        if hasattr(client, "messages"):
            orig = client.messages.create
            self._originals["messages.create"] = orig
            # An async client's ``create`` returns a coroutine; wrapping it
            # with the sync wrapper would finish the span before the call
            # has actually run.
            wrap = self._wrap_async if self._is_coroutine_callable(orig) else self._wrap_sync
            client.messages.create = wrap(orig)

        return client

    def disconnect(self) -> None:
        """Restore every patched attribute and forget the client."""
        if self._client is None:
            return
        for key, orig in self._originals.items():
            try:
                parts = key.split(".")
                obj = self._client
                for part in parts[:-1]:
                    obj = getattr(obj, part)
                setattr(obj, parts[-1], orig)
            except Exception:
                log.warning("Could not restore %s", key)
        self._client = None
        self._originals.clear()

    @staticmethod
    def _is_coroutine_callable(fn: Any) -> bool:
        """Return True if *fn*, or what it wraps via ``__wrapped__``, is ``async def``."""
        import inspect  # local import: only needed while patching

        try:
            # SDK methods may sit behind a functools.wraps-style decorator,
            # which hides the coroutine flag of the underlying function.
            target = inspect.unwrap(fn)
        except ValueError:  # cyclic __wrapped__ chain
            target = fn
        return inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(target)

    @staticmethod
    def _finish(span: Any, token: Any, response: Any) -> None:
        """Record *response* on *span*; never let recording break the call."""
        try:
            finish_llm_span(span, token, response, _extract_output, _extract_response_meta)
        except Exception:
            # finish_llm_span has already reset the context token in its own
            # ``finally``; calling fail_llm_span here would reset it twice.
            log.warning("Could not record anthropic.messages.create span", exc_info=True)

    def _wrap_sync(self, original: Any) -> Any:
        """Return a sync wrapper that records an LLM span around *original*."""

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            span, token = create_llm_span("anthropic.messages.create", kwargs, _CAPTURE_PARAMS)
            if span is None:
                # No active trace: behave exactly like the original.
                return original(*args, **kwargs)
            try:
                response = original(*args, **kwargs)
            except Exception as exc:
                fail_llm_span(span, token, exc)
                raise
            self._finish(span, token, response)
            return response

        return wrapped

    def _wrap_async(self, original: Any) -> Any:
        """Return an async wrapper that records an LLM span around *original*."""

        async def wrapped(*args: Any, **kwargs: Any) -> Any:
            span, token = create_llm_span("anthropic.messages.create", kwargs, _CAPTURE_PARAMS)
            if span is None:
                return await original(*args, **kwargs)
            try:
                response = await original(*args, **kwargs)
            except Exception as exc:
                fail_llm_span(span, token, exc)
                raise
            self._finish(span, token, response)
            return response

        return wrapped
--git a/src/layerlens/instrument/adapters/providers/litellm.py b/src/layerlens/instrument/adapters/providers/litellm.py new file mode 100644 index 0000000..9f12514 --- /dev/null +++ b/src/layerlens/instrument/adapters/providers/litellm.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +from .openai import _extract_output, _extract_response_meta +from ._base_provider import fail_llm_span, create_llm_span, finish_llm_span + +_CAPTURE_PARAMS = frozenset( + { + "model", + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "response_format", + } +) + +_original_completion = None +_original_acompletion = None + + +def instrument_litellm() -> None: + try: + import litellm + except ImportError as err: + raise ImportError( + "The 'litellm' package is required for LiteLLM instrumentation. Install it with: pip install litellm" + ) from err + + global _original_completion, _original_acompletion + + if _original_completion is None: + _original_completion = litellm.completion + orig_sync = _original_completion + + def patched_completion(*args: Any, **kwargs: Any) -> Any: + span, token = create_llm_span("litellm.completion", kwargs, _CAPTURE_PARAMS) + if span is None: + return orig_sync(*args, **kwargs) + try: + response = orig_sync(*args, **kwargs) + finish_llm_span(span, token, response, _extract_output, _extract_response_meta) + return response + except Exception as exc: + fail_llm_span(span, token, exc) + raise + + litellm.completion = patched_completion + + if _original_acompletion is None: + _original_acompletion = litellm.acompletion + orig_async = _original_acompletion + + async def patched_acompletion(*args: Any, **kwargs: Any) -> Any: + span, token = create_llm_span("litellm.acompletion", kwargs, _CAPTURE_PARAMS) + if span is None: + return await orig_async(*args, **kwargs) + try: + response = await orig_async(*args, **kwargs) + finish_llm_span(span, token, response, _extract_output, 
class OpenAIProvider:
    """Patch ``client.chat.completions`` so calls inside a trace become LLM spans.

    ``create`` is always patched; ``acreate`` is patched too when the client
    has that attribute. Originals are restored by :meth:`disconnect`.
    """

    def __init__(self) -> None:
        self._client: Any = None
        # Dotted attribute path -> original callable, restored in disconnect().
        self._originals: Dict[str, Any] = {}

    def connect_client(self, client: Any) -> Any:
        """Instrument *client* in place and return it.

        If this provider is already connected, it is disconnected first so
        the real originals (not previously installed wrappers) are what get
        stored for restoration.
        """
        if self._client is not None:
            self.disconnect()
        self._client = client

        if hasattr(client, "chat") and hasattr(client.chat, "completions"):
            completions = client.chat.completions

            orig = completions.create
            self._originals["chat.completions.create"] = orig
            # An async client's ``create`` returns a coroutine; wrapping it
            # with the sync wrapper would finish the span before the call
            # has actually run.
            wrap = self._wrap_async if self._is_coroutine_callable(orig) else self._wrap_sync
            completions.create = wrap(orig)

            if hasattr(completions, "acreate"):
                async_orig = completions.acreate
                self._originals["chat.completions.acreate"] = async_orig
                completions.acreate = self._wrap_async(async_orig)

        return client

    def disconnect(self) -> None:
        """Restore every patched attribute and forget the client."""
        if self._client is None:
            return
        for key, orig in self._originals.items():
            try:
                parts = key.split(".")
                obj = self._client
                for part in parts[:-1]:
                    obj = getattr(obj, part)
                setattr(obj, parts[-1], orig)
            except Exception:
                log.warning("Could not restore %s", key)
        self._client = None
        self._originals.clear()

    @staticmethod
    def _is_coroutine_callable(fn: Any) -> bool:
        """Return True if *fn*, or what it wraps via ``__wrapped__``, is ``async def``."""
        import inspect  # local import: only needed while patching

        try:
            # SDK methods may sit behind a functools.wraps-style decorator,
            # which hides the coroutine flag of the underlying function.
            target = inspect.unwrap(fn)
        except ValueError:  # cyclic __wrapped__ chain
            target = fn
        return inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(target)

    @staticmethod
    def _finish(span: Any, token: Any, response: Any) -> None:
        """Record *response* on *span*; never let recording break the call."""
        try:
            finish_llm_span(span, token, response, _extract_output, _extract_response_meta)
        except Exception:
            # finish_llm_span has already reset the context token in its own
            # ``finally``; calling fail_llm_span here would reset it twice.
            log.warning("Could not record openai.chat.completions.create span", exc_info=True)

    def _wrap_sync(self, original: Any) -> Any:
        """Return a sync wrapper that records an LLM span around *original*."""

        def wrapped(*args: Any, **kwargs: Any) -> Any:
            span, token = create_llm_span("openai.chat.completions.create", kwargs, _CAPTURE_PARAMS)
            if span is None:
                # No active trace: behave exactly like the original.
                return original(*args, **kwargs)
            try:
                response = original(*args, **kwargs)
            except Exception as exc:
                fail_llm_span(span, token, exc)
                raise
            self._finish(span, token, response)
            return response

        return wrapped

    def _wrap_async(self, original: Any) -> Any:
        """Return an async wrapper that records an LLM span around *original*."""

        async def wrapped(*args: Any, **kwargs: Any) -> Any:
            span, token = create_llm_span("openai.chat.completions.create", kwargs, _CAPTURE_PARAMS)
            if span is None:
                return await original(*args, **kwargs)
            try:
                response = await original(*args, **kwargs)
            except Exception as exc:
                fail_llm_span(span, token, exc)
                raise
            self._finish(span, token, response)
            return response

        return wrapped
OpenAIProvider: + global _provider_instance + if _provider_instance is not None: + _provider_instance.disconnect() + _provider_instance = OpenAIProvider() + _provider_instance.connect_client(client) + return _provider_instance + + +def uninstrument_openai() -> None: + global _provider_instance + if _provider_instance is not None: + _provider_instance.disconnect() + _provider_instance = None diff --git a/tests/instrument/__init__.py b/tests/instrument/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/instrument/conftest.py b/tests/instrument/conftest.py new file mode 100644 index 0000000..0dda669 --- /dev/null +++ b/tests/instrument/conftest.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock + +import pytest + + +@pytest.fixture +def mock_client(): + client = Mock() + client.traces = Mock() + client.traces.upload = Mock() + return client + + +@pytest.fixture +def capture_trace(mock_client): + uploaded = {} + + def _capture(path): + with open(path) as f: + uploaded["trace"] = json.load(f) + + mock_client.traces.upload.side_effect = _capture + return uploaded diff --git a/tests/instrument/test_adapters.py b/tests/instrument/test_adapters.py new file mode 100644 index 0000000..4a430d9 --- /dev/null +++ b/tests/instrument/test_adapters.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import sys +import types +import importlib +from uuid import uuid4 +from unittest.mock import Mock + + +class TestLangChainAdapter: + def _setup_langchain_mock(self): + mock_lc_core = types.ModuleType("langchain_core") + mock_lc_callbacks = types.ModuleType("langchain_core.callbacks") + + class FakeBaseCallbackHandler: + def __init__(self): + pass + + mock_lc_callbacks.BaseCallbackHandler = FakeBaseCallbackHandler + mock_lc_core.callbacks = mock_lc_callbacks + + sys.modules["langchain_core"] = mock_lc_core + sys.modules["langchain_core.callbacks"] = mock_lc_callbacks + + def 
_teardown_langchain_mock(self): + for key in list(sys.modules.keys()): + if key.startswith("langchain_core"): + del sys.modules[key] + + def _get_handler(self, mock_client, capture_trace): + from layerlens.instrument.adapters.frameworks import langchain as lc_mod + + importlib.reload(lc_mod) + return lc_mod.LangChainCallbackHandler(mock_client) + + def test_builds_span_tree(self, mock_client, capture_trace): + self._setup_langchain_mock() + try: + handler = self._get_handler(mock_client, capture_trace) + + chain_run_id = uuid4() + llm_run_id = uuid4() + + handler.on_chain_start( + {"name": "RunnableSequence", "id": ["RunnableSequence"]}, + {"question": "What is AI?"}, + run_id=chain_run_id, + ) + handler.on_llm_start( + {"name": "ChatOpenAI", "id": ["ChatOpenAI"]}, + ["What is AI?"], + run_id=llm_run_id, + parent_run_id=chain_run_id, + ) + + llm_response = Mock() + llm_response.generations = [[Mock(text="AI is...")]] + llm_response.llm_output = {"token_usage": {"total_tokens": 50}, "model_name": "gpt-4"} + handler.on_llm_end(llm_response, run_id=llm_run_id) + handler.on_chain_end({"output": "AI is..."}, run_id=chain_run_id) + + root = capture_trace["trace"][0] + assert root["name"] == "RunnableSequence" + assert root["kind"] == "chain" + assert len(root["children"]) == 1 + + llm = root["children"][0] + assert llm["name"] == "ChatOpenAI" + assert llm["kind"] == "llm" + assert llm["output"] == "AI is..." 
+ assert llm["metadata"]["model"] == "gpt-4" + assert llm["metadata"]["usage"]["total_tokens"] == 50 + finally: + self._teardown_langchain_mock() + + def test_tracks_tools_and_retrievers(self, mock_client, capture_trace): + self._setup_langchain_mock() + try: + handler = self._get_handler(mock_client, capture_trace) + + chain_id = uuid4() + tool_id = uuid4() + retriever_id = uuid4() + + handler.on_chain_start({"name": "Agent"}, {"input": "test"}, run_id=chain_id) + handler.on_tool_start({"name": "search"}, "query", run_id=tool_id, parent_run_id=chain_id) + handler.on_tool_end("results", run_id=tool_id) + handler.on_retriever_start({"name": "vectorstore"}, "query", run_id=retriever_id, parent_run_id=chain_id) + + docs = [Mock(page_content="doc1", metadata={"source": "a"})] + handler.on_retriever_end(docs, run_id=retriever_id) + handler.on_chain_end({"output": "done"}, run_id=chain_id) + + root = capture_trace["trace"][0] + assert root["name"] == "Agent" + assert len(root["children"]) == 2 + assert root["children"][0]["kind"] == "tool" + assert root["children"][1]["kind"] == "retriever" + finally: + self._teardown_langchain_mock() + + def test_error_on_chain(self, mock_client, capture_trace): + self._setup_langchain_mock() + try: + handler = self._get_handler(mock_client, capture_trace) + + chain_id = uuid4() + handler.on_chain_start({"name": "FailChain"}, {"input": "x"}, run_id=chain_id) + handler.on_chain_error(ValueError("broke"), run_id=chain_id) + + root = capture_trace["trace"][0] + assert root["status"] == "error" + assert root["error"] == "broke" + finally: + self._teardown_langchain_mock() + + def test_null_serialized_handled(self, mock_client, capture_trace): + self._setup_langchain_mock() + try: + handler = self._get_handler(mock_client, capture_trace) + + run_id = uuid4() + handler.on_chain_start(None, {"input": "x"}, run_id=run_id) + handler.on_chain_end({"output": "done"}, run_id=run_id) + + root = capture_trace["trace"][0] + assert root["name"] == 
"unknown" + assert root["status"] == "ok" + finally: + self._teardown_langchain_mock() diff --git a/tests/instrument/test_core.py b/tests/instrument/test_core.py new file mode 100644 index 0000000..2b1fc00 --- /dev/null +++ b/tests/instrument/test_core.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import os + +import pytest + +from layerlens.instrument import SpanData, span, trace +from layerlens.instrument._context import _current_span, _current_recorder +from layerlens.instrument._recorder import TraceRecorder + + +class TestTraceDecorator: + def test_basic_trace(self, mock_client): + @trace(mock_client) + def my_func(x): + return x * 2 + + result = my_func(5) + assert result == 10 + mock_client.traces.upload.assert_called_once() + + def test_trace_with_custom_name(self, mock_client, capture_trace): + @trace(mock_client, name="custom_name") + def my_func(): + return "ok" + + my_func() + assert capture_trace["trace"][0]["name"] == "custom_name" + + def test_trace_captures_input(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(query): + return "result" + + my_func("hello") + assert capture_trace["trace"][0]["input"] == "hello" + + def test_trace_captures_output(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + return {"answer": 42} + + my_func() + assert capture_trace["trace"][0]["output"] == {"answer": 42} + + def test_trace_on_error(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + raise ValueError("boom") + + with pytest.raises(ValueError, match="boom"): + my_func() + + assert capture_trace["trace"][0]["status"] == "error" + assert capture_trace["trace"][0]["error"] == "boom" + + def test_trace_cleans_up_context(self, mock_client): + @trace(mock_client) + def my_func(): + return "ok" + + my_func() + assert _current_recorder.get() is None + assert _current_span.get() is None + + def test_trace_cleans_up_context_on_error(self, mock_client): + @trace(mock_client) + def 
my_func(): + raise RuntimeError("fail") + + with pytest.raises(RuntimeError): + my_func() + + assert _current_recorder.get() is None + assert _current_span.get() is None + + +class TestSpanContextManager: + def test_span_creates_child(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + with span("child_span", kind="llm") as s: + s.output = "child output" + return "done" + + my_func() + root = capture_trace["trace"][0] + assert len(root["children"]) == 1 + child = root["children"][0] + assert child["name"] == "child_span" + assert child["kind"] == "llm" + assert child["output"] == "child output" + assert child["parent_id"] == root["span_id"] + + def test_nested_spans(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + with span("outer", kind="chain") as s1: + s1.output = "outer" + with span("inner", kind="llm") as s2: + s2.output = "inner" + return "done" + + my_func() + root = capture_trace["trace"][0] + outer = root["children"][0] + assert outer["name"] == "outer" + inner = outer["children"][0] + assert inner["name"] == "inner" + assert inner["parent_id"] == outer["span_id"] + + def test_span_on_error(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + try: + with span("failing") as s: + raise ValueError("span error") + except ValueError: + pass + return "recovered" + + my_func() + child = capture_trace["trace"][0]["children"][0] + assert child["status"] == "error" + assert child["error"] == "span error" + + def test_span_without_trace_noops(self): + with span("orphan", kind="llm") as s: + s.output = "test" + assert s.output == "test" + + def test_multiple_sibling_spans(self, mock_client, capture_trace): + @trace(mock_client) + def my_func(): + with span("retrieve", kind="retriever") as s: + s.output = ["doc1", "doc2"] + with span("generate", kind="llm") as s: + s.output = "answer" + return "done" + + my_func() + root = capture_trace["trace"][0] + assert len(root["children"]) == 2 + assert 
root["children"][0]["name"] == "retrieve" + assert root["children"][1]["name"] == "generate" + + +class TestTraceRecorder: + def test_flush_calls_upload(self, mock_client): + recorder = TraceRecorder(mock_client) + recorder.root = SpanData(name="root") + recorder.root.finish() + + recorder.flush() + mock_client.traces.upload.assert_called_once() + + path = mock_client.traces.upload.call_args[0][0] + assert not os.path.exists(path) + + def test_flush_noop_without_root(self, mock_client): + recorder = TraceRecorder(mock_client) + recorder.flush() + mock_client.traces.upload.assert_not_called() diff --git a/tests/instrument/test_providers.py b/tests/instrument/test_providers.py new file mode 100644 index 0000000..fceeb1d --- /dev/null +++ b/tests/instrument/test_providers.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import sys +import types +from unittest.mock import Mock + +from layerlens.instrument import trace + + +def _openai_response(): + r = Mock() + r.choices = [Mock()] + r.choices[0].message = Mock() + r.choices[0].message.role = "assistant" + r.choices[0].message.content = "Hello!" + r.usage = Mock() + r.usage.prompt_tokens = 10 + r.usage.completion_tokens = 5 + r.usage.total_tokens = 15 + r.model = "gpt-4" + return r + + +def _anthropic_response(): + r = Mock() + block = Mock() + block.type = "text" + block.text = "I'm Claude!" 
+ r.content = [block] + r.usage = Mock() + r.usage.input_tokens = 20 + r.usage.output_tokens = 10 + r.model = "claude-3-opus" + r.stop_reason = "end_turn" + return r + + +class TestOpenAIProvider: + def test_instrument_creates_span(self, mock_client, capture_trace): + from layerlens.instrument.adapters.providers.openai import OpenAIProvider + + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=_openai_response()) + + provider = OpenAIProvider() + provider.connect_client(openai_client) + + @trace(mock_client) + def my_agent(): + return ( + openai_client.chat.completions.create(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) + .choices[0] + .message.content + ) + + my_agent() + llm = capture_trace["trace"][0]["children"][0] + assert llm["kind"] == "llm" + assert llm["name"] == "openai.chat.completions.create" + assert llm["metadata"]["model"] == "gpt-4" + assert llm["metadata"]["usage"]["total_tokens"] == 15 + assert llm["output"]["content"] == "Hello!" + + def test_passthrough_without_trace(self): + from layerlens.instrument.adapters.providers.openai import OpenAIProvider + + openai_client = Mock() + openai_client.chat.completions.create = Mock(return_value=_openai_response()) + + provider = OpenAIProvider() + provider.connect_client(openai_client) + + result = openai_client.chat.completions.create(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" 
+ + def test_disconnect_restores(self): + from layerlens.instrument.adapters.providers.openai import OpenAIProvider + + openai_client = Mock() + original = openai_client.chat.completions.create + + provider = OpenAIProvider() + provider.connect_client(openai_client) + assert openai_client.chat.completions.create is not original + + provider.disconnect() + assert openai_client.chat.completions.create is original + + def test_instrument_convenience_function(self): + from layerlens.instrument.adapters.providers.openai import instrument_openai, uninstrument_openai + + openai_client = Mock() + original = openai_client.chat.completions.create + instrument_openai(openai_client) + assert openai_client.chat.completions.create is not original + uninstrument_openai() + + +class TestAnthropicProvider: + def test_instrument_creates_span(self, mock_client, capture_trace): + from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider + + anthropic_client = Mock() + anthropic_client.messages.create = Mock(return_value=_anthropic_response()) + + provider = AnthropicProvider() + provider.connect_client(anthropic_client) + + @trace(mock_client) + def my_agent(): + return ( + anthropic_client.messages.create( + model="claude-3-opus", max_tokens=1024, messages=[{"role": "user", "content": "Hi"}] + ) + .content[0] + .text + ) + + my_agent() + llm = capture_trace["trace"][0]["children"][0] + assert llm["kind"] == "llm" + assert llm["name"] == "anthropic.messages.create" + assert llm["output"]["text"] == "I'm Claude!" 
+ assert llm["metadata"]["usage"]["input_tokens"] == 20 + assert llm["metadata"]["response_model"] == "claude-3-opus" + assert llm["metadata"]["stop_reason"] == "end_turn" + + def test_disconnect_restores(self): + from layerlens.instrument.adapters.providers.anthropic import AnthropicProvider + + anthropic_client = Mock() + original = anthropic_client.messages.create + + provider = AnthropicProvider() + provider.connect_client(anthropic_client) + provider.disconnect() + assert anthropic_client.messages.create is original + + +class TestLiteLLMProvider: + def setup_method(self): + self.mock_litellm = types.ModuleType("litellm") + self.mock_litellm.completion = Mock(return_value=_openai_response()) + self.mock_litellm.acompletion = Mock() + sys.modules["litellm"] = self.mock_litellm + + def teardown_method(self): + for key in list(sys.modules.keys()): + if key.startswith("litellm"): + del sys.modules[key] + from layerlens.instrument.adapters.providers import litellm as litellm_adapter + + litellm_adapter._original_completion = None + litellm_adapter._original_acompletion = None + + def test_instrument_creates_span(self, mock_client, capture_trace): + from layerlens.instrument.adapters.providers.litellm import instrument_litellm + + instrument_litellm() + + @trace(mock_client) + def my_agent(): + import litellm + + return ( + litellm.completion(model="gpt-4", messages=[{"role": "user", "content": "Hi"}]) + .choices[0] + .message.content + ) + + my_agent() + llm = capture_trace["trace"][0]["children"][0] + assert llm["kind"] == "llm" + assert llm["name"] == "litellm.completion" + assert llm["metadata"]["model"] == "gpt-4" + + def test_passthrough_without_trace(self): + from layerlens.instrument.adapters.providers.litellm import instrument_litellm + + instrument_litellm() + import litellm + + result = litellm.completion(model="gpt-4", messages=[]) + assert result.choices[0].message.content == "Hello!" 
+ + def test_uninstrument(self): + from layerlens.instrument.adapters.providers.litellm import instrument_litellm, uninstrument_litellm + + original = self.mock_litellm.completion + instrument_litellm() + assert self.mock_litellm.completion is not original + uninstrument_litellm() + assert self.mock_litellm.completion is original + + +class TestProviderErrorHandling: + def test_span_captures_error(self, mock_client, capture_trace): + from layerlens.instrument.adapters.providers.openai import OpenAIProvider + + openai_client = Mock() + openai_client.chat.completions.create = Mock(side_effect=RuntimeError("API error")) + + provider = OpenAIProvider() + provider.connect_client(openai_client) + + @trace(mock_client) + def my_agent(): + try: + openai_client.chat.completions.create(model="gpt-4", messages=[]) + except RuntimeError: + pass + return "recovered" + + my_agent() + llm = capture_trace["trace"][0]["children"][0] + assert llm["status"] == "error" + assert llm["error"] == "API error" diff --git a/tests/instrument/test_types.py b/tests/instrument/test_types.py new file mode 100644 index 0000000..272edb3 --- /dev/null +++ b/tests/instrument/test_types.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import time + +from layerlens.instrument._types import SpanData + + +class TestSpanData: + def test_defaults(self): + s = SpanData(name="test") + assert s.name == "test" + assert len(s.span_id) == 16 + assert s.parent_id is None + assert s.status == "ok" + assert s.kind == "internal" + assert s.input is None + assert s.output is None + assert s.error is None + assert s.metadata == {} + assert s.children == [] + assert s.end_time is None + assert s.start_time <= time.time() + + def test_finish_ok(self): + s = SpanData(name="test") + s.finish() + assert s.end_time is not None + assert s.status == "ok" + assert s.error is None + + def test_finish_error(self): + s = SpanData(name="test") + s.finish(error="something broke") + assert s.end_time is not None + assert 
s.status == "error" + assert s.error == "something broke" + + def test_to_dict(self): + parent = SpanData(name="parent") + child = SpanData(name="child", parent_id=parent.span_id) + parent.children.append(child) + + d = parent.to_dict() + assert d["name"] == "parent" + assert d["parent_id"] is None + assert len(d["children"]) == 1 + assert d["children"][0]["name"] == "child" + assert d["children"][0]["parent_id"] == parent.span_id + + def test_to_dict_nested(self): + root = SpanData(name="root") + child1 = SpanData(name="c1", parent_id=root.span_id) + child2 = SpanData(name="c2", parent_id=child1.span_id) + root.children.append(child1) + child1.children.append(child2) + + d = root.to_dict() + assert d["children"][0]["children"][0]["name"] == "c2" From 9cfb986068c80360dad6a9faf868312a502bc35d Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Thu, 19 Mar 2026 21:54:17 -0700 Subject: [PATCH 2/3] fix: scope ARG ignore to just instrument tests --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 30e15da..16cb9a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -140,6 +140,7 @@ known-first-party = ["openai", "tests"] "bin/**.py" = ["T201", "T203"] "scripts/**.py" = ["T201", "T203"] "tests/**.py" = ["T201", "T203"] +"tests/instrument/**.py" = ["T201", "T203", "ARG"] "examples/**.py" = ["T201", "T203"] "src/layerlens/cli/**" = ["T201", "T203"] "src/layerlens/instrument/adapters/frameworks/langchain.py" = ["ARG002"] From a4e75b0f2a07d67843b16d393a356c5a5458820c Mon Sep 17 00:00:00 2001 From: Gary <59334078+garrettallen14@users.noreply.github.com> Date: Thu, 19 Mar 2026 22:08:41 -0700 Subject: [PATCH 3/3] fix: resolve mypy errors in langchain and litellm adapters --- requirements-dev.lock | 130 +++++++++++++++ requirements.lock | 150 ++++++++++++++++++ .../adapters/frameworks/langchain.py | 2 +- .../instrument/adapters/providers/litellm.py | 4 +- 4 files changed, 283 
insertions(+), 3 deletions(-) diff --git a/requirements-dev.lock b/requirements-dev.lock index 81a18f2..55fc02e 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -10,10 +10,28 @@ # universal: false -e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-doc==0.0.4 + # via typer annotated-types==0.7.0 # via pydantic +anthropic==0.86.0 + # via layerlens anyio==4.9.0 + # via anthropic # via httpx + # via openai +async-timeout==5.0.1 + # via aiohttp +attrs==26.1.0 + # via aiohttp + # via jsonschema + # via referencing backports-tarfile==1.2.0 # via jaraco-context build==1.3.0 @@ -27,30 +45,57 @@ charset-normalizer==3.4.3 # via requests click==8.1.8 # via layerlens + # via litellm + # via typer coverage==7.10.2 # via pytest-cov cryptography==46.0.5 # via secretstorage +distro==1.9.0 + # via anthropic + # via openai +docstring-parser==0.17.0 + # via anthropic docutils==0.22 # via readme-renderer exceptiongroup==1.3.0 # via anyio # via pytest +fastuuid==0.14.0 + # via litellm +filelock==3.19.1 + # via huggingface-hub +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +fsspec==2025.10.0 + # via huggingface-hub h11==0.16.0 # via httpcore +hf-xet==1.4.2 + # via huggingface-hub httpcore==1.0.9 # via httpx httpx==0.28.1 + # via anthropic + # via huggingface-hub + # via langsmith # via layerlens + # via litellm + # via openai +huggingface-hub==1.7.2 + # via tokenizers id==1.5.0 # via twine idna==3.10 # via anyio # via httpx # via requests + # via yarl importlib-metadata==8.7.0 # via build # via keyring + # via litellm # via twine iniconfig==2.1.0 # via pytest @@ -63,15 +108,39 @@ jaraco-functools==4.2.1 jeepney==0.9.0 # via keyring # via secretstorage +jinja2==3.1.6 + # via litellm +jiter==0.13.0 + # via anthropic + # via openai +jsonpatch==1.33 + # via langchain-core +jsonpointer==3.0.0 + # via jsonpatch +jsonschema==4.25.1 + # via litellm +jsonschema-specifications==2025.9.1 + # via 
jsonschema keyring==25.6.0 # via twine +langchain-core==0.3.83 + # via layerlens +langsmith==0.4.37 + # via langchain-core +litellm==1.82.6 + # via layerlens markdown-it-py==3.0.0 # via rich +markupsafe==3.0.3 + # via jinja2 mdurl==0.1.2 # via markdown-it-py more-itertools==10.7.0 # via jaraco-classes # via jaraco-functools +multidict==6.7.1 + # via aiohttp + # via yarl mypy==1.17.0 mypy-extensions==1.1.0 # via mypy @@ -79,8 +148,16 @@ nh3==0.3.0 # via readme-renderer nodeenv==1.9.1 # via pyright +openai==2.29.0 + # via layerlens + # via litellm +orjson==3.11.5 + # via langsmith packaging==25.0 # via build + # via huggingface-hub + # via langchain-core + # via langsmith # via pytest # via twine pathspec==0.12.1 @@ -88,10 +165,18 @@ pathspec==0.12.1 pluggy==1.6.0 # via pytest # via pytest-cov +propcache==0.4.1 + # via aiohttp + # via yarl pycparser==2.23 # via cffi pydantic==2.11.7 + # via anthropic + # via langchain-core + # via langsmith # via layerlens + # via litellm + # via openai pydantic-core==2.33.2 # via pydantic pygments==2.19.2 @@ -104,42 +189,87 @@ pyright==1.1.399 pytest==8.4.1 # via pytest-cov pytest-cov==6.2.1 +python-dotenv==1.2.1 + # via litellm +pyyaml==6.0.3 + # via huggingface-hub + # via langchain-core readme-renderer==44.0 # via twine +referencing==0.36.2 + # via jsonschema + # via jsonschema-specifications +regex==2026.1.15 + # via tiktoken requests==2.32.5 # via id + # via langsmith # via requests-toolbelt + # via tiktoken # via twine requests-toolbelt==1.0.0 + # via langsmith # via twine rfc3986==2.0.0 # via twine rich==14.1.0 # via twine + # via typer +rpds-py==0.27.1 + # via jsonschema + # via referencing ruff==0.12.7 secretstorage==3.3.3 # via keyring +shellingham==1.5.4 + # via typer sniffio==1.3.1 + # via anthropic # via anyio + # via openai +tenacity==9.1.2 + # via langchain-core +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm tomli==2.2.1 # via build # via coverage # via mypy # via pytest +tqdm==4.67.3 + # via 
huggingface-hub + # via openai twine==6.1.0 +typer==0.23.2 + # via huggingface-hub typing-extensions==4.14.1 + # via aiosignal + # via anthropic # via anyio # via cryptography # via exceptiongroup + # via huggingface-hub + # via langchain-core + # via multidict # via mypy + # via openai # via pydantic # via pydantic-core # via pyright + # via referencing # via typing-inspection typing-inspection==0.4.1 # via pydantic urllib3==2.5.0 # via requests # via twine +uuid-utils==0.14.1 + # via langchain-core +yarl==1.22.0 + # via aiohttp zipp==3.23.0 # via importlib-metadata +zstandard==0.25.0 + # via langsmith diff --git a/requirements.lock b/requirements.lock index 1a890c9..5b9bb3d 100644 --- a/requirements.lock +++ b/requirements.lock @@ -10,37 +10,187 @@ # universal: false -e file:. +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.13.3 + # via litellm +aiosignal==1.4.0 + # via aiohttp +annotated-doc==0.0.4 + # via typer annotated-types==0.7.0 # via pydantic +anthropic==0.86.0 + # via layerlens anyio==4.9.0 + # via anthropic # via httpx + # via openai +async-timeout==5.0.1 + # via aiohttp +attrs==26.1.0 + # via aiohttp + # via jsonschema + # via referencing certifi==2025.7.14 # via httpcore # via httpx + # via requests +charset-normalizer==3.4.6 + # via requests click==8.1.8 # via layerlens + # via litellm + # via typer +distro==1.9.0 + # via anthropic + # via openai +docstring-parser==0.17.0 + # via anthropic exceptiongroup==1.3.0 # via anyio +fastuuid==0.14.0 + # via litellm +filelock==3.19.1 + # via huggingface-hub +frozenlist==1.8.0 + # via aiohttp + # via aiosignal +fsspec==2025.10.0 + # via huggingface-hub h11==0.16.0 # via httpcore +hf-xet==1.4.2 + # via huggingface-hub httpcore==1.0.9 # via httpx httpx==0.28.1 + # via anthropic + # via huggingface-hub + # via langsmith # via layerlens + # via litellm + # via openai +huggingface-hub==1.7.2 + # via tokenizers idna==3.10 # via anyio # via httpx + # via requests + # via yarl +importlib-metadata==8.7.1 + # via 
litellm +jinja2==3.1.6 + # via litellm +jiter==0.13.0 + # via anthropic + # via openai +jsonpatch==1.33 + # via langchain-core +jsonpointer==3.0.0 + # via jsonpatch +jsonschema==4.25.1 + # via litellm +jsonschema-specifications==2025.9.1 + # via jsonschema +langchain-core==0.3.83 + # via layerlens +langsmith==0.4.37 + # via langchain-core +litellm==1.82.6 + # via layerlens +markdown-it-py==3.0.0 + # via rich +markupsafe==3.0.3 + # via jinja2 +mdurl==0.1.2 + # via markdown-it-py +multidict==6.7.1 + # via aiohttp + # via yarl +openai==2.29.0 + # via layerlens + # via litellm +orjson==3.11.5 + # via langsmith +packaging==25.0 + # via huggingface-hub + # via langchain-core + # via langsmith +propcache==0.4.1 + # via aiohttp + # via yarl pydantic==2.11.7 + # via anthropic + # via langchain-core + # via langsmith # via layerlens + # via litellm + # via openai pydantic-core==2.33.2 # via pydantic +pygments==2.19.2 + # via rich +python-dotenv==1.2.1 + # via litellm +pyyaml==6.0.3 + # via huggingface-hub + # via langchain-core +referencing==0.36.2 + # via jsonschema + # via jsonschema-specifications +regex==2026.1.15 + # via tiktoken +requests==2.32.5 + # via langsmith + # via requests-toolbelt + # via tiktoken +requests-toolbelt==1.0.0 + # via langsmith +rich==14.3.3 + # via typer +rpds-py==0.27.1 + # via jsonschema + # via referencing +shellingham==1.5.4 + # via typer sniffio==1.3.1 + # via anthropic # via anyio + # via openai +tenacity==9.1.2 + # via langchain-core +tiktoken==0.12.0 + # via litellm +tokenizers==0.22.2 + # via litellm +tqdm==4.67.3 + # via huggingface-hub + # via openai +typer==0.23.2 + # via huggingface-hub typing-extensions==4.14.1 + # via aiosignal + # via anthropic # via anyio # via exceptiongroup + # via huggingface-hub + # via langchain-core + # via multidict + # via openai # via pydantic # via pydantic-core + # via referencing # via typing-inspection typing-inspection==0.4.1 # via pydantic +urllib3==2.6.3 + # via requests +uuid-utils==0.14.1 + # 
via langchain-core +yarl==1.22.0 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata +zstandard==0.25.0 + # via langsmith diff --git a/src/layerlens/instrument/adapters/frameworks/langchain.py b/src/layerlens/instrument/adapters/frameworks/langchain.py index 5a213a9..1e30ee6 100644 --- a/src/layerlens/instrument/adapters/frameworks/langchain.py +++ b/src/layerlens/instrument/adapters/frameworks/langchain.py @@ -9,7 +9,7 @@ from langchain_core.callbacks import BaseCallbackHandler # pyright: ignore[reportAssignmentType] except ImportError: - class BaseCallbackHandler: # type: ignore[no-redef,misc] + class BaseCallbackHandler: # type: ignore[no-redef] def __init_subclass__(cls, **kwargs: Any) -> None: raise ImportError( "The 'langchain-core' package is required for LangChain instrumentation. " diff --git a/src/layerlens/instrument/adapters/providers/litellm.py b/src/layerlens/instrument/adapters/providers/litellm.py index 9f12514..f84497c 100644 --- a/src/layerlens/instrument/adapters/providers/litellm.py +++ b/src/layerlens/instrument/adapters/providers/litellm.py @@ -17,8 +17,8 @@ } ) -_original_completion = None -_original_acompletion = None +_original_completion: Any = None +_original_acompletion: Any = None def instrument_litellm() -> None: