diff --git a/src/agentex/lib/core/tracing/__init__.py b/src/agentex/lib/core/tracing/__init__.py index 9f91f9cec..639f3ba8e 100644 --- a/src/agentex/lib/core/tracing/__init__.py +++ b/src/agentex/lib/core/tracing/__init__.py @@ -1,5 +1,19 @@ from agentex.types.span import Span from agentex.lib.core.tracing.trace import Trace, AsyncTrace from agentex.lib.core.tracing.tracer import Tracer, AsyncTracer +from agentex.lib.core.tracing.span_queue import ( + AsyncSpanQueue, + get_default_span_queue, + shutdown_default_span_queue, +) -__all__ = ["Trace", "AsyncTrace", "Span", "Tracer", "AsyncTracer"] +__all__ = [ + "Trace", + "AsyncTrace", + "Span", + "Tracer", + "AsyncTracer", + "AsyncSpanQueue", + "get_default_span_queue", + "shutdown_default_span_queue", +] diff --git a/src/agentex/lib/core/tracing/span_queue.py b/src/agentex/lib/core/tracing/span_queue.py new file mode 100644 index 000000000..e881cc1da --- /dev/null +++ b/src/agentex/lib/core/tracing/span_queue.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import asyncio +from enum import Enum +from dataclasses import dataclass + +from agentex.types.span import Span +from agentex.lib.utils.logging import make_logger +from agentex.lib.core.tracing.processors.tracing_processor_interface import ( + AsyncTracingProcessor, +) + +logger = make_logger(__name__) + + +class SpanEventType(str, Enum): + START = "start" + END = "end" + + +@dataclass +class _SpanQueueItem: + event_type: SpanEventType + span: Span + processors: list[AsyncTracingProcessor] + + +class AsyncSpanQueue: + """Background FIFO queue for async span processing. + + Span events are enqueued synchronously (non-blocking) and processed + sequentially by a background drain task. This keeps tracing HTTP calls + off the critical request path while preserving start-before-end ordering. + """ + + def __init__(self) -> None: + self._queue: asyncio.Queue[_SpanQueueItem] = asyncio.Queue() + self._drain_task: asyncio.Task[None] | None = None + self._stopping = False + + def enqueue( + self, + event_type: SpanEventType, + span: Span, + processors: list[AsyncTracingProcessor], + ) -> None: + if self._stopping: + logger.warning("Span queue is shutting down, dropping %s event for span %s", event_type.value, span.id) + return + self._ensure_drain_running() + self._queue.put_nowait(_SpanQueueItem(event_type=event_type, span=span, processors=processors)) + + def _ensure_drain_running(self) -> None: + if self._drain_task is None or self._drain_task.done(): + self._drain_task = asyncio.create_task(self._drain_loop()) + + async def _drain_loop(self) -> None: + while True: + item = await self._queue.get() + try: + if item.event_type == SpanEventType.START: + coros = [p.on_span_start(item.span) for p in item.processors] + else: + coros = [p.on_span_end(item.span) for p in item.processors] + results = await asyncio.gather(*coros, return_exceptions=True) + for result in results: + if isinstance(result, Exception): + logger.error( + "Tracing processor error during %s for span %s", + item.event_type.value, + item.span.id, + exc_info=result, + ) + except Exception: + logger.exception("Unexpected error in span queue drain loop for span %s", item.span.id) + finally: + self._queue.task_done() + + async def shutdown(self, timeout: float = 30.0) -> None: + self._stopping = True + if self._queue.empty() and (self._drain_task is None or self._drain_task.done()): + return + try: + await asyncio.wait_for(self._queue.join(), timeout=timeout) + except asyncio.TimeoutError: + logger.warning( + "Span queue shutdown timed out after %.1fs with %d items remaining", timeout, self._queue.qsize() + ) + if self._drain_task is not None and not self._drain_task.done(): + self._drain_task.cancel() + try: + await self._drain_task + except asyncio.CancelledError: + pass + + +_default_span_queue: AsyncSpanQueue | None = None + + +def get_default_span_queue() -> AsyncSpanQueue: + global _default_span_queue + if _default_span_queue is None: + _default_span_queue = AsyncSpanQueue() + return _default_span_queue + + +async def shutdown_default_span_queue(timeout: float = 30.0) -> None: + global _default_span_queue + if _default_span_queue is not None: + await _default_span_queue.shutdown(timeout=timeout) + _default_span_queue = None diff --git a/src/agentex/lib/core/tracing/trace.py b/src/agentex/lib/core/tracing/trace.py index 2ba1d489e..7925df7fc 100644 --- a/src/agentex/lib/core/tracing/trace.py +++ b/src/agentex/lib/core/tracing/trace.py @@ -1,7 +1,6 @@ from __future__ import annotations import uuid -import asyncio from typing import Any, AsyncGenerator from datetime import UTC, datetime from contextlib import contextmanager, asynccontextmanager @@ -12,6 +11,11 @@ from agentex.types.span import Span from agentex.lib.utils.logging import make_logger from agentex.lib.utils.model_utils import recursive_model_dump +from agentex.lib.core.tracing.span_queue import ( + SpanEventType, + AsyncSpanQueue, + get_default_span_queue, +) from agentex.lib.core.tracing.processors.tracing_processor_interface import ( SyncTracingProcessor, AsyncTracingProcessor, @@ -173,6 +177,7 @@ def __init__( processors: list[AsyncTracingProcessor], client: AsyncAgentex, trace_id: str | None = None, + span_queue: AsyncSpanQueue | None = None, ): """ Initialize a new trace with the specified trace ID. @@ -180,10 +185,12 @@ def __init__( Args: trace_id: Required trace ID to use for this trace. processors: Optional list of tracing processors to use for this trace. + span_queue: Optional span queue for background processing. """ self.processors = processors self.client = client self.trace_id = trace_id + self._span_queue = span_queue or get_default_span_queue() async def start_span( self, @@ -225,9 +232,7 @@ async def start_span( ) if self.processors: - await asyncio.gather( - *[processor.on_span_start(span) for processor in self.processors] - ) + self._span_queue.enqueue(SpanEventType.START, span.model_copy(deep=True), self.processors) return span @@ -252,9 +257,7 @@ async def end_span( span.data = recursive_model_dump(span.data) if span.data else None if self.processors: - await asyncio.gather( - *[processor.on_span_end(span) for processor in self.processors] - ) + self._span_queue.enqueue(SpanEventType.END, span.model_copy(deep=True), self.processors) return span diff --git a/src/agentex/lib/core/tracing/tracer.py b/src/agentex/lib/core/tracing/tracer.py index da77bec95..3af79977e 100644 --- a/src/agentex/lib/core/tracing/tracer.py +++ b/src/agentex/lib/core/tracing/tracer.py @@ -2,6 +2,7 @@ from agentex import Agentex, AsyncAgentex from agentex.lib.core.tracing.trace import Trace, AsyncTrace +from agentex.lib.core.tracing.span_queue import AsyncSpanQueue from agentex.lib.core.tracing.tracing_processor_manager import ( get_sync_tracing_processors, get_async_tracing_processors, @@ -55,12 +56,13 @@ def __init__(self, client: AsyncAgentex): """ self.client = client - def trace(self, trace_id: str | None = None) -> AsyncTrace: + def trace(self, trace_id: str | None = None, span_queue: AsyncSpanQueue | None = None) -> AsyncTrace: """ Create a new trace with the given trace ID. Args: trace_id: The trace ID to use. + span_queue: Optional span queue for background processing. Returns: A new AsyncTrace instance. @@ -69,4 +71,5 @@ def trace(self, trace_id: str | None = None) -> AsyncTrace: processors=get_async_tracing_processors(), client=self.client, trace_id=trace_id, + span_queue=span_queue, ) diff --git a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py index b625eaa1c..56507ade5 100644 --- a/src/agentex/lib/sdk/fastacp/base/base_acp_server.py +++ b/src/agentex/lib/sdk/fastacp/base/base_acp_server.py @@ -32,6 +32,7 @@ from agentex.lib.environment_variables import EnvironmentVariables, refreshed_environment_variables from agentex.types.task_message_update import TaskMessageUpdate, StreamTaskMessageFull from agentex.types.task_message_content import TaskMessageContent +from agentex.lib.core.tracing.span_queue import shutdown_default_span_queue from agentex.lib.sdk.fastacp.base.constants import ( FASTACP_HEADER_SKIP_EXACT, FASTACP_HEADER_SKIP_PREFIXES, @@ -103,7 +104,10 @@ async def lifespan_context(app: FastAPI): # noqa: ARG001 else: logger.warning("AGENTEX_BASE_URL not set, skipping agent registration") - yield + try: + yield + finally: + await shutdown_default_span_queue() return lifespan_context diff --git a/tests/lib/core/tracing/test_span_queue.py b/tests/lib/core/tracing/test_span_queue.py new file mode 100644 index 000000000..1f39fb25d --- /dev/null +++ b/tests/lib/core/tracing/test_span_queue.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +import time +import uuid +import asyncio +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch + +from agentex.types.span import Span +from agentex.lib.core.tracing.span_queue import SpanEventType, AsyncSpanQueue + + +def _make_span(span_id: str | None = None) -> Span: + return Span( + id=span_id or str(uuid.uuid4()), + name="test-span", + start_time=datetime.now(UTC), + trace_id="trace-1", + ) + + +def _make_processor(**overrides: AsyncMock) -> AsyncMock: + proc = AsyncMock() + proc.on_span_start = overrides.get("on_span_start", AsyncMock()) + proc.on_span_end = overrides.get("on_span_end", AsyncMock()) + return proc + + +class TestAsyncSpanQueueNonBlocking: + async def test_enqueue_does_not_block(self): + started = asyncio.Event() + + async def slow_start(span: Span) -> None: + started.set() + await asyncio.sleep(1.0) + + slow_processor = _make_processor( + on_span_start=AsyncMock(side_effect=slow_start), + ) + queue = AsyncSpanQueue() + span = _make_span() + + start = time.monotonic() + queue.enqueue(SpanEventType.START, span, [slow_processor]) + elapsed = time.monotonic() - start + + assert elapsed < 0.01, f"enqueue took {elapsed:.3f}s — should be instant" + + # Wait for the processor to start (proves it was called) + await asyncio.wait_for(started.wait(), timeout=2.0) + await queue.shutdown() + + +class TestAsyncSpanQueueOrdering: + async def test_fifo_ordering_preserved(self): + call_log: list[tuple[str, str]] = [] + + async def record_start(span: Span) -> None: + call_log.append(("start", span.id)) + + async def record_end(span: Span) -> None: + call_log.append(("end", span.id)) + + proc = _make_processor( + on_span_start=AsyncMock(side_effect=record_start), + on_span_end=AsyncMock(side_effect=record_end), + ) + queue = AsyncSpanQueue() + + span_a = _make_span("span-a") + span_b = _make_span("span-b") + + queue.enqueue(SpanEventType.START, span_a, [proc]) + queue.enqueue(SpanEventType.END, span_a, [proc]) + queue.enqueue(SpanEventType.START, span_b, [proc]) + queue.enqueue(SpanEventType.END, span_b, [proc]) + + await queue.shutdown() + + assert call_log == [ + ("start", "span-a"), + ("end", "span-a"), + ("start", "span-b"), + ("end", "span-b"), + ] + + +class TestAsyncSpanQueueErrorHandling: + async def test_error_in_processor_does_not_stop_drain(self): + call_count = 0 + + async def failing_start(span: Span) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("simulated failure") + + proc = _make_processor( + on_span_start=AsyncMock(side_effect=failing_start), + ) + queue = AsyncSpanQueue() + + queue.enqueue(SpanEventType.START, _make_span(), [proc]) + queue.enqueue(SpanEventType.START, _make_span(), [proc]) + + await queue.shutdown() + + assert call_count == 2, "Second event should still be processed after first fails" + + +class TestAsyncSpanQueueShutdown: + async def test_shutdown_drains_remaining_items(self): + processed: list[str] = [] + + async def track(span: Span) -> None: + processed.append(span.id) + + proc = _make_processor(on_span_start=AsyncMock(side_effect=track)) + queue = AsyncSpanQueue() + + for i in range(5): + queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc]) + + await queue.shutdown() + + assert len(processed) == 5 + + async def test_shutdown_timeout(self): + async def stuck_start(span: Span) -> None: + await asyncio.sleep(60) + + stuck_processor = _make_processor( + on_span_start=AsyncMock(side_effect=stuck_start), + ) + queue = AsyncSpanQueue() + queue.enqueue(SpanEventType.START, _make_span(), [stuck_processor]) + + # Give the drain loop a moment to pick up the item + await asyncio.sleep(0.05) + + start = time.monotonic() + await queue.shutdown(timeout=0.1) + elapsed = time.monotonic() - start + + assert elapsed < 1.0, f"shutdown should not hang — took {elapsed:.1f}s" + + async def test_enqueue_after_shutdown_is_dropped(self): + proc = _make_processor() + queue = AsyncSpanQueue() + await queue.shutdown() + + queue.enqueue(SpanEventType.START, _make_span(), [proc]) + + proc.on_span_start.assert_not_called() + + +class TestAsyncSpanQueueIntegration: + async def test_integration_with_async_trace(self): + call_log: list[tuple[str, str]] = [] + + async def record_start(span: Span) -> None: + call_log.append(("start", span.id)) + + async def record_end(span: Span) -> None: + call_log.append(("end", span.id)) + + proc = _make_processor( + on_span_start=AsyncMock(side_effect=record_start), + on_span_end=AsyncMock(side_effect=record_end), + ) + queue = AsyncSpanQueue() + + # Patch get_async_tracing_processors to return our mock + with patch( + "agentex.lib.core.tracing.trace.get_default_span_queue", + return_value=queue, + ): + from agentex.lib.core.tracing.trace import AsyncTrace + + mock_client = MagicMock() + trace = AsyncTrace( + processors=[proc], + client=mock_client, + trace_id="test-trace", + span_queue=queue, + ) + + async with trace.span("test-operation") as span: + output: dict[str, object] = {"result": "ok"} + span.output = output + + await queue.shutdown() + + assert len(call_log) == 2 + assert call_log[0][0] == "start" + assert call_log[1][0] == "end" + # Same span ID for both events + assert call_log[0][1] == call_log[1][1]