diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 03b7fd4..c8ce231 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -231,10 +231,12 @@ def init( resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) - # Default to a local self-hosted backend (the one `make dev` brings - # up at deploy/compose.yaml). OSS users pointing at a remote - # deployment override via ws_url= or ADRIAN_WS_URL. - resolved_ws_url = os.getenv("ADRIAN_WS_URL") or ws_url or "ws://localhost:8080/ws" + # Default to the hosted Adrian backend so `adrian.init(api_key=...)` + # Just Works for freemium users. Self-hosted users override via + # ws_url= or ADRIAN_WS_URL. + resolved_ws_url = ( + os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" + ) resolved_session = ( os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() ) @@ -520,6 +522,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_base_tool() + _patch_agent_executor() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -531,12 +535,14 @@ def _auto_instrument_langchain() -> None: def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(Runnable, "_adrian_patched", False): return original_invoke = Runnable.invoke original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -544,9 +550,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync Runnable call.""" config = _inject_callbacks(config) - return 
original_invoke(self, input, config, **kwargs) async def patched_ainvoke( @@ -555,13 +559,32 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async Runnable call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + Runnable.invoke = patched_invoke # type: ignore[assignment] Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] logger.debug("Patched Runnable.invoke / ainvoke") @@ -634,12 +657,14 @@ def patched_configure( def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(BaseChatModel, "_adrian_chat_model_patched", False): return original_invoke = BaseChatModel.invoke original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -647,9 +672,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync chat model call.""" config = _inject_callbacks(config) - 
return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -658,13 +681,32 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async chat model call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] logger.debug("Patched BaseChatModel.invoke / ainvoke") @@ -760,29 +802,15 @@ async def patched_astream( # --- 5. ToolNode --- -def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage], +def _extract_tool_calls( # pyright: ignore[reportUnusedFunction] + state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, Any]]: - """Extract tool_calls from the ToolNode input. + """Extract tool_calls from ToolNode input (all three dispatch shapes). - ``ToolNode`` is reached with three input shapes: - 1. a state dict whose ``"messages"`` key holds the message list - (hand-built ``StateGraph`` with ``ToolNode`` as a node), or - 2. a bare list of messages, or - 3. 
a single per-tool-call dict ``{"__type", "tool_call", "state"}`` - — how langgraph-prebuilt / ``create_react_agent`` dispatch each - tool call. The id lives at ``input["tool_call"]["id"]``. - - Shape 3 was previously unhandled: the function returned ``[]``, so the - block/HITL gate never found a tool_call_id and ran the tool un-gated. - - Args: - state: The ToolNode input (any of the three shapes above). - - Returns: - List of tool call dicts, or an empty list when none is found. + Returns full tool_call dicts (with id, name, args) for backward + compat with tests and callers that need the full shape. """ - # Shape 3: per-tool-call dispatch (create_react_agent / prebuilt ToolNode). + # Shape 3: per-tool-call dict from _afunc dispatch if isinstance(state, dict) and "tool_call" in state: tc = state["tool_call"] if isinstance(tc, dict) and tc.get("id"): @@ -798,10 +826,13 @@ def _extract_tool_calls( ] return [] + # Shape 1/2: state dict or message list if isinstance(state, dict): messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - else: + elif isinstance(state, list): messages = list(state) + else: + return [] for msg in reversed(messages): if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): @@ -813,55 +844,28 @@ def _extract_tool_calls( def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is - in-scope, halt; if not, continue. + HITL resolutions override per-MAD policy when present.
""" if verdict.HasField("hitl"): return not verdict.hitl.continue_execution mad_prefix = verdict.mad_code[:2] - in_scope = { + return { "M0": verdict.policy.policy_m0, "M2": verdict.policy.policy_m2, "M3": verdict.policy.policy_m3, "M4": verdict.policy.policy_m4, }.get(mad_prefix, False) - return in_scope - - -def _build_blocked_response( - tool_calls: list[dict[str, str]], -) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls. - - Args: - tool_calls: List of tool call dicts extracted from the AIMessage. - - Returns: - Dict in the format ToolNode expects. - """ - blocked_messages: list[ToolMessage] = [ - ToolMessage( - content="[BLOCKED by security policy]", - tool_call_id=str(tc.get("id", "")), - name=str(tc.get("name", "")), - ) - for tc in tool_calls - ] - - return {"messages": blocked_messages} - def _patch_tool_node() -> None: - """Patch ``ToolNode.invoke`` / ``ainvoke``. + """Patch ToolNode.invoke / ainvoke (and astream when present) to inject callbacks. - In block mode, the async patch waits for the preceding LLM's verdict - before executing tools. On BLOCK (unless overridden by ``on_block``) - it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + No verdict gate runs at this layer: the patched methods only inject + Adrian callbacks and delegate to the original ToolNode methods. The + block/HITL verdict gate lives on BaseTool (``_patch_base_tool``), + which patched_ainvoke's inline comment names as the single gate + layer; gating here as well made the BaseTool gate time out.
""" try: from langgraph.prebuilt import ToolNode @@ -873,43 +877,96 @@ def _patch_tool_node() -> None: original_invoke = ToolNode.invoke original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync ToolNode invocation.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks; in BLOCK / HITL modes wait for verdict. + config = _inject_callbacks(config) + # Verdict gate removed - BaseTool.ainvoke/arun is the single + # gate layer. Gating here too caused double-gate: ToolNode + # consumed the verdict future, BaseTool's gate registered a + # fresh future that never resolved → 30s timeout on a benign + # verdict. Callback injection is kept so events still flow. + return await original_ainvoke(self, input, config=config, **kwargs) - Per-tool-call correlation: every tool_call.id is mapped (in - ``WebSocketClient`` ) to the event_id of the LLM that emitted - it. Each ToolNode invocation awaits its specific LLM's verdict, - race-free under parallel agents, no graph-wide pause. 
- """ + async def patched_astream( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - ws = _ws_client + assert original_astream is not None # guarded by line below + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode.invoke / ainvoke / astream") + + +# --- 6. BaseTool (universal verdict gate) --- + + +_BLOCKED_CONTENT = "[BLOCKED by security policy]" + + +def _patch_base_tool() -> None: + """Patch ``BaseTool.invoke`` / ``ainvoke`` / ``arun`` with the verdict gate. + + Tools dispatched by ToolNode, create_react_agent, or a manual + ``tool.invoke(tool_call)`` loop reach ``BaseTool.invoke`` (sync) or + ``BaseTool.ainvoke`` (async); ``arun`` is gated as well, but only + when the caller passes a ``tool_call_id``. + The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` + TypedDict), awaits the classifier verdict for the producing LLM + event, and returns a blocked ``ToolMessage`` instead of running the + tool body when ``_should_halt`` says so (HITL decision, else policy). + + A verdict timeout is fail-closed (the tool is blocked): in block + mode the absence of a verdict is treated as a policy violation. + When ``ws.policy_active()`` is false (e.g. MODE_ALERT), no gate fires.
+ """ + from langchain_core.tools import BaseTool + from langchain_core.tools.base import ( + _is_tool_call, # pyright: ignore[reportPrivateUsage] + ) + + if getattr(BaseTool, "_adrian_base_tool_patched", False): + return + + original_invoke = BaseTool.invoke + original_ainvoke = BaseTool.ainvoke + + def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 + """Extract tool_call_id from a ToolCall input, or None.""" + if isinstance(input, dict) and _is_tool_call(input): + return input.get("id") + return None + + async def _async_gate(tool_call_id: str) -> bool: + """Returns True if the tool should be BLOCKED.""" + ws = _ws_client if ws is None: - return await original_ainvoke(self, input, config=config, **kwargs) + return False - # First-tool-call window: the recv loop may not have processed - # ``LoginAck`` yet, so ``policy_active()`` reads False even - # when the org is in BLOCK or HITL. Wait for the LoginAck - # event before checking. If it doesn't arrive within the - # window, halt, refusing to run is the only safe outcome - # when we can't verify the org's policy. if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] try: await asyncio.wait_for( @@ -918,36 +975,26 @@ async def patched_ainvoke( ) except TimeoutError: logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" + "BaseTool: LoginAck not received within 5s; " + "blocking tool (refusing to run without verified policy)" ) - return _build_blocked_response(_extract_tool_calls(input)) + return True if not ws.policy_active(): - return await original_ainvoke(self, input, config=config, **kwargs) - - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) - - if not tool_call_id: - # Direct ToolNode invocation outside an LLM flow, no - # producing event_id to wait on, so let the tool run. 
- return await original_ainvoke(self, input, config=config, **kwargs) + return False cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) if verdict is None: + # Fail-closed in block mode: no verdict = block. logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", + "BaseTool: verdict timeout for tool_call_id=%s; " + "blocking (fail-closed in MODE_BLOCK)", tool_call_id, ) - return await original_ainvoke(self, input, config=config, **kwargs) + return True if _should_halt(verdict): logger.warning( @@ -955,11 +1002,208 @@ async def patched_ainvoke( verdict.event_id, verdict.mad_code, ) - return _build_blocked_response(tool_calls) + return True + + return False + + def _sync_gate(tool_call_id: str) -> bool: + """Sync verdict gate - works for pure-sync and worker-thread callers. + + Worker-thread (the common LangGraph case: ``StructuredTool.ainvoke`` + dispatches a *sync* tool via ``run_in_executor(self.invoke)``, so the + gate runs on a thread-pool worker while the WS event loop runs on + another thread): bridges the async gate onto the WS loop via + ``run_coroutine_threadsafe`` and blocks the worker until the verdict + resolves. + + Pure-sync (no event loop anywhere): runs ``_async_gate`` to + completion on this thread. + + Event-loop thread (calling ``tool.invoke`` directly from async + code): cannot block without deadlocking - returns False (skip). + The async path (``BaseTool.ainvoke``) handles this case. + + Thread detection uses ``asyncio.get_running_loop()`` rather than + ``get_event_loop()``: the latter raises ``RuntimeError`` on a worker + thread (no loop *set* there, since Python 3.10+), which would + misclassify the worker-thread case as "no loop" and skip the gate - + leaving sync tools ungated under ``create_react_agent``. 
+ """ + ws = _ws_client + if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] + return False # NOTE(review): skips the gate when LoginAck is not set yet, whereas _async_gate waits for it and blocks on timeout; confirm intended + + # Is THIS thread running an event loop? + try: + asyncio.get_running_loop() + except RuntimeError: + pass # no loop on this thread: worker thread or pure-sync caller + else: + # On the event-loop thread: blocking here would stall the loop, so + # this sync call is NOT gated; only BaseTool.ainvoke callers are. + return False + + # Worker thread: the WS loop runs elsewhere - bridge onto it and + # block this worker until the verdict resolves. ``_async_gate`` owns + # the wait policy (bounded with fail-closed in MODE_BLOCK, indefinite + # in MODE_HITL where execution must pause until a human acts), so we + # wait on the future with no timeout of our own - a finite timeout + # here would fail-open a HITL hold once it elapsed. Fail closed (treat + # as halt) if the bridge itself raises. + main_loop = getattr(ws, "_loop", None) + if main_loop is not None and main_loop.is_running(): + try: + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result() + except Exception: + return True + + # Pure-sync caller, no loop anywhere - run the gate to completion. + try: + return asyncio.run(_async_gate(tool_call_id)) + except Exception: + return True + + def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 + """Return a blocked response compatible with ToolNode. + + Returns a ToolMessage carrying ``_BLOCKED_CONTENT`` and the given id. + Falls back to the bare string if building the ToolMessage raises.
+ """ + try: + return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") + except Exception: + return _BLOCKED_CONTENT + + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and _sync_gate(tc_id): + return _blocked_response(tc_id) + return original_invoke(self, input, config=config, **kwargs) + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and await _async_gate(tc_id): + return _blocked_response(tc_id) return await original_ainvoke(self, input, config=config, **kwargs) - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke") + original_arun = BaseTool.arun + + async def patched_arun( + self: Any, # noqa: ANN401 + tool_input: Any, # noqa: ANN401 + *args: Any, + tool_call_id: str | None = None, + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Gate on arun - AgentExecutor calls tool.arun directly.""" + if tool_call_id and await _async_gate(tool_call_id): + return _blocked_response(tool_call_id) + return await original_arun( + self, tool_input, *args, tool_call_id=tool_call_id, **kwargs + ) + + BaseTool.invoke = patched_invoke # type: ignore[assignment] + BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseTool.arun = patched_arun # type: ignore[assignment] + BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") + + +# --- 7. 
AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- + + +def _patch_agent_executor() -> None: + """Patch AgentExecutor._aperform_agent_action for the executor path. + + AgentExecutor calls tool.arun without forwarding tool_call_id, + so the BaseTool.arun gate can't extract it. The tool_call_id lives + on agent_action.tool_call_id (set by OpenAI-style parsers). We + intercept here, await the verdict, and return a blocked observation + instead of calling the tool. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + + async def patched_aperform( + self: Any, + name_to_tool_map: Any, + color_mapping: Any, # noqa: ANN401 + agent_action: Any, + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + tc_id = getattr(agent_action, "tool_call_id", None) + if tc_id: + ws = _ws_client + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; blocking" + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if ws.policy_active(): + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for 
tool_call_id=%s, blocking (fail-closed)", + tc_id, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 1ab5df4..9eb4bc3 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -52,6 +52,8 @@ _MAX_RUN_ID_MAP = 1024 # Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). _MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 _DEFAULT_REPLAY_BUFFER_FRAMES = 1000 @@ -254,6 +256,10 @@ def __init__( # Set by close() so _handle_disconnect knows not to spawn a reconnect # during a graceful shutdown. self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None # Futures awaited by the patched ToolNode.ainvoke when the # active mode requires a wait (BLOCK or HITL). Each resolves # with the matching ``Verdict`` proto. 
Futures survive a @@ -472,6 +478,7 @@ async def connect(self) -> None: backoff = _INITIAL_BACKOFF loop = asyncio.get_running_loop() + self._loop = loop headers: dict[str, str] = {} @@ -491,7 +498,6 @@ async def connect(self) -> None: disconnected_at = self._disconnected_at is_reconnect = disconnected_at is not None - if disconnected_at is not None: downtime = time.monotonic() - disconnected_at self._disconnected_at = None @@ -927,6 +933,18 @@ def register_pending( return fut + def _evict_resolved_verdicts(self) -> None: + """Drop the oldest resolved futures while over the cap; stop at the first in-flight one.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop here, so the dict may stay above the cap. + break + async def wait_for_verdict( self, event_id: str, @@ -939,25 +957,30 @@ async def wait_for_verdict( ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the verdict, or ``None`` on timeout (fail-open). - Cleans up the ``_pending_verdicts`` entry on either path: - ``_on_verdict_frame`` only resolves the future, the dict - ownership belongs here so a late ``register_pending`` after the - verdict has already arrived can still find the resolved future. + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. a later gate layer asking after an + earlier one consumed the verdict) finds the resolved future and + returns instantly instead of timing out. Timed-out futures are + removed immediately (NOTE(review): a cancelled wait is not; confirm); + resolved ones are evicted once the dict exceeds ``_MAX_PENDING_VERDICTS``.
""" fut = self.register_pending(event_id) try: - return await asyncio.wait_for(fut, timeout=timeout) + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result except TimeoutError: logger.warning( "Verdict timeout for event_id=%s after %ss", event_id, timeout, ) - - return None - finally: + # Timed-out future is useless - remove so a retry can + # register a fresh one. self._pending_verdicts.pop(event_id, None) + return None async def wait_for_tool_verdict( self, diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 0d1c352..742249b 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -142,10 +142,16 @@ async def test_looks_up_llm_event_id_and_resolves(self) -> None: class TestToolNodePatchBlocking: async def test_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: - """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → halt with synthetic ToolMessage.""" + """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → BaseTool.ainvoke gate blocks. - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + The verdict gate lives on BaseTool (the universal layer), not + ToolNode.ainvoke. Uses an async tool so BaseTool.ainvoke (not + BaseTool.invoke) is the entry point - matching the production + path for create_react_agent with async tools. + """ + + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" _real_tool.called = True # type: ignore[attr-defined] return x @@ -180,6 +186,7 @@ def _real_tool(x: str) -> str: result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + # BaseTool.ainvoke gate blocks - tool body does NOT run. 
assert _real_tool.called is False # type: ignore[attr-defined] msgs = result["messages"] assert len(msgs) == 1 @@ -190,7 +197,7 @@ async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) @@ -226,11 +233,12 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] - async def test_timeout_fail_open_runs_tool(self, tmp_path: Path) -> None: + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + """Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run).""" captured: list[str] = [] - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" captured.append(x) return x @@ -248,7 +256,7 @@ def _real_tool(x: str) -> str: _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) ws._connected.set() ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" - # No pending future → wait_for_verdict times out → fail-open. + # No pending future → wait_for_verdict times out → fail-closed (MODE_BLOCK). tool_node = ToolNode([_real_tool]) ai = AIMessage( @@ -259,7 +267,8 @@ def _real_tool(x: str) -> str: await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] - assert captured == ["hi"] + # Fail-closed: tool should NOT have run. 
+ assert captured == [] class TestModeAlert: @@ -268,7 +277,7 @@ async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) @@ -299,3 +308,251 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] assert not ws._pending_verdicts + + +class TestSyncToolNodeBlocking: + """Regression: sync (``def``) tools dispatched by ToolNode / create_react_agent. + + The tests in ``TestToolNodePatchBlocking`` use ``async def`` tools, so + they exercise ``BaseTool.ainvoke`` (the async gate). A sync ``def`` tool + takes a different path: ``StructuredTool.ainvoke`` has no coroutine, so it + runs ``self.invoke`` via ``run_in_executor`` on a worker thread. The gate + therefore lands in ``BaseTool.invoke`` -> ``_sync_gate`` on a thread that + is not running an event loop, and ``_sync_gate`` must bridge the gate onto + the WS loop. A regression here (e.g. probing the thread with + ``get_event_loop()``, which raises on a worker thread) silently skips the + gate and lets block-level tool calls run ungated under create_react_agent. + """ + + @staticmethod + def _prep(ws: WebSocketClient, policy_m4: bool, mad_code: str) -> None: + """Drive a logged-in MODE_BLOCK state with a pre-resolved verdict. + + ``ws._loop`` points at the test loop so the worker-thread bridge in + ``_sync_gate`` has a running target, mirroring production where the + WS loop lives on its own thread, separate from the Pregel worker. 
+ """ + policy = _apply_mode(ws, pb.MODE_BLOCK, policy_m4=policy_m4) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + fut = ws.register_pending("llm-evt") + fut.set_result(pb.Verdict(event_id="llm-evt", mad_code=mad_code, policy=policy)) + + async def test_sync_tool_block_verdict_halts(self, tmp_path: Path) -> None: + """MODE_BLOCK + policy_m4 + M4 verdict: sync tool body must NOT run.""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + self._prep(ws, policy_m4=True, mad_code="M4_a") + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + # Sync tool body must NOT run; a BLOCKED ToolMessage is returned. 
+ assert captured == [] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + async def test_sync_tool_out_of_scope_runs(self, tmp_path: Path) -> None: + """MODE_BLOCK, M2 verdict with policy_m2 false: sync tool runs (no over-block).""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + self._prep(ws, policy_m4=True, mad_code="M2") # m2 not in policy scope + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + + @staticmethod + def _prep_hitl( + ws: WebSocketClient, + ) -> tuple[pb.PolicySnapshot, asyncio.Future[pb.Verdict]]: + """MODE_HITL, logged in, with an UNRESOLVED pending verdict (held). + + Returns the policy and the pending future so the test can resolve it + later, standing in for a human approve/reject. + """ + policy = _apply_mode(ws, pb.MODE_HITL, policy_m4=True) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + fut = ws.register_pending("llm-evt") + return policy, fut + + @staticmethod + def _tool_call_state() -> dict[str, Any]: + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + return {"messages": [ai]} + + async def test_sync_tool_hitl_holds_until_human_then_blocks_on_reject( + self, tmp_path: Path + ) -> None: + """MODE_HITL: a sync tool is HELD indefinitely, never fail-opens. 
+ + The gate must wait past ``block_timeout`` (the bounded MODE_BLOCK wait + does not apply to HITL); a human reject then halts the tool. Regression + for the worker-thread bridge fail-opening a HITL hold once a finite + ``future.result`` timeout elapsed. + """ + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.5, + ) + + ws = adrian._ws_client + assert ws is not None + policy, fut = self._prep_hitl(ws) + + task = asyncio.ensure_future( + ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + ) + + # Held well past block_timeout: neither run nor returned, waiting for a human. + await asyncio.sleep(1.5) + assert not task.done() + assert captured == [] + + # Human rejects -> HITL verdict with continue_execution=False. 
+ verdict = pb.Verdict(event_id="llm-evt", mad_code="M4_a", policy=policy) + verdict.hitl.continue_execution = False + fut.set_result(verdict) + + result = await asyncio.wait_for(task, timeout=2.0) + assert captured == [] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + async def test_sync_tool_hitl_resumes_on_approve(self, tmp_path: Path) -> None: + """MODE_HITL: after a human approve (continue_execution=True), the sync tool runs.""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.5, + ) + + ws = adrian._ws_client + assert ws is not None + policy, fut = self._prep_hitl(ws) + + task = asyncio.ensure_future( + ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + ) + + await asyncio.sleep(0.3) + assert not task.done() + assert captured == [] + + verdict = pb.Verdict(event_id="llm-evt", mad_code="M4_a", policy=policy) + verdict.hitl.continue_execution = True + fut.set_result(verdict) + + await asyncio.wait_for(task, timeout=2.0) + assert captured == ["hi"] + + async def test_sync_tool_block_timeout_fails_closed(self, tmp_path: Path) -> None: + """MODE_BLOCK: no verdict before block_timeout -> sync tool blocked (fail-closed).""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Sync tool stub; records execution.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=0.1, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) + ws._connected.set() + ws._loop = asyncio.get_running_loop() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + 
ws.register_pending("llm-evt") # never resolved -> verdict times out + + result = await ToolNode([_real_tool]).ainvoke( # pyright: ignore[reportUnknownMemberType] + self._tool_call_state(), config=_runtime_config() + ) + + assert captured == [] + assert "BLOCKED" in result["messages"][0].content diff --git a/sdk/python/tests/test_block_mode_races.py b/sdk/python/tests/test_block_mode_races.py index fa0ad57..16d8e4a 100644 --- a/sdk/python/tests/test_block_mode_races.py +++ b/sdk/python/tests/test_block_mode_races.py @@ -5,17 +5,17 @@ LLM calls; no running backend. Scenarios mirror the validated shapes from the multi-agent work: - S1 subagents-as-tools - director → worker (nested) - S2 handoffs - triage → specialist (sequential) - S3 router - parallel fan-out via Send() - S4 hierarchical - 3-level deep (director → team-lead → worker) - S5 custom workflow - deterministic + LLM nodes mixed - S6 swarm - back-and-forth handoffs (Alice ↔ Bob) - S7 supervisor - central dispatcher to N workers - S8 deep research - parallel researchers via asyncio.gather + S1 subagents-as-tools , director → worker (nested) + S2 handoffs , triage → specialist (sequential) + S3 router , parallel fan-out via Send() + S4 hierarchical , 3-level deep (director → team-lead → worker) + S5 custom workflow , deterministic + LLM nodes mixed + S6 swarm , back-and-forth handoffs (Alice ↔ Bob) + S7 supervisor , central dispatcher to N workers + S8 deep research , parallel researchers via asyncio.gather The invariant under test: for EVERY pattern, each ToolNode invocation -blocks on the verdict of the LLM that emitted its specific tool_call.id - +blocks on the verdict of the LLM that emitted its specific tool_call.id, never a sibling, never a parent, never a stale global.
""" @@ -117,9 +117,9 @@ def _init_block_mode(tmp_path: Path, block_timeout: float = 1.0) -> Any: def _tool(name: str, captured: list[str]) -> Any: - """Build a named stub tool that records its argument.""" + """Build a named async stub tool that records its argument.""" - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(f"{name}:{x}") diff --git a/sdk/python/tests/test_exec_modes.py b/sdk/python/tests/test_exec_modes.py index 1ea8ae1..f3f5e42 100644 --- a/sdk/python/tests/test_exec_modes.py +++ b/sdk/python/tests/test_exec_modes.py @@ -61,7 +61,7 @@ def _cleanup() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] def _stub_tool(captured: list[str]) -> Any: # noqa: ANN401 - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(x) diff --git a/sdk/python/tests/test_extract_tool_calls.py b/sdk/python/tests/test_extract_tool_calls.py index 9910673..9cad0d4 100644 --- a/sdk/python/tests/test_extract_tool_calls.py +++ b/sdk/python/tests/test_extract_tool_calls.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2026 SecureAgentics -"""Unit tests for ``_extract_tool_calls`` — the function whose missing shape +"""Unit tests for ``_extract_tool_calls`` - the function whose missing shape handling let block/HITL skip the verdict wait for ``create_react_agent`` agents. Covers all three ToolNode input shapes. Shape 3 (per-tool-call dispatch) is the diff --git a/sdk/python/tests/test_parent_context_scenarios.py b/sdk/python/tests/test_parent_context_scenarios.py index 157b5e1..327884b 100644 --- a/sdk/python/tests/test_parent_context_scenarios.py +++ b/sdk/python/tests/test_parent_context_scenarios.py @@ -1,4 +1,4 @@ -"""End-to-end parent-context derivation per multi-agent scenario (S1–S8). +"""End-to-end parent-context derivation per multi-agent scenario (S1-S8). 
Fires the LangChain-shaped callback sequence each scenario produces - with the ``langgraph_checkpoint_ns`` metadata LangGraph would emit -