From 5d7f71f4e390bc2fb5d7b8b71d52390dcc802bbc Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 25 Jun 2026 05:32:40 +0000 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20fail-closed=20driver=20execution=20?= =?UTF-8?q?=E2=80=94=20audit,=20deadlines,=20pooled=20HTTP=20client?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Close four related gaps on the invocation path so the kernel's "controlled, audited execution" promise (I-02) holds on every exit, not just the happy path. Grouped because they share one code area (kernel/_invoke.py + drivers/http.py) and one fail-closed implementation path: any fault funnels through DriverError -> fallback -> failure trace -> budget release. - #152: capture *any* driver exception (not only DriverError) as a failed attempt, and wrap the post-driver pipeline (handle store, firewall transform, token counting) so a fault there records a failure trace and releases the reservation exactly once before re-raising. Extract execute_with_fallback into kernel/_driver_exec.py to keep _invoke.py within its module-size budget (lower its ratchet ceiling 400 -> 382). - #191: optional per-invocation deadline via the signed `invoke_timeout_s` token constraint, enforced per driver attempt (single-shot and streaming, incl. a stream inactivity timeout) with asyncio.wait_for; a timeout becomes a synthetic DriverError. Sourced from the token (not a kernel ctor arg) so the deadline is tamper-evident and the kernel module stays within its size budget. - #194: HTTPDriver holds one long-lived, pooled httpx.AsyncClient with configurable Limits and an aclose() lifecycle; optional max_response_bytes streams and aborts oversized bodies before buffering. - #197: a non-JSON body from a JSON endpoint raises a typed DriverError (content-type + redaction-safe snippet); new HTTPEndpoint.response_format ("json"|"text") supports text APIs deliberately. Tests: new tests/test_driver_exec.py (fallback, deadline, fault capture); rewritten HTTP driver tests for the pooled-client/streaming contract; kernel integration tests for non-DriverError audit, firewall-failure audit, budget release, and deadline enforcement. Docs + CHANGELOG updated. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_018VdP3irZSbbPyyoVS1QQbD --- CHANGELOG.md | 25 ++ docs/integrations.md | 40 ++- examples/http_driver_demo.py | 4 +- src/weaver_kernel/drivers/http.py | 167 ++++++++++--- src/weaver_kernel/kernel/_driver_exec.py | 151 ++++++++++++ src/weaver_kernel/kernel/_invoke.py | 94 +++---- src/weaver_kernel/kernel/_stream.py | 36 ++- tests/test_architecture.py | 2 +- tests/test_driver_exec.py | 152 ++++++++++++ tests/test_drivers.py | 298 ++++++++++++++--------- tests/test_kernel.py | 155 +++++++++++- 11 files changed, 920 insertions(+), 204 deletions(-) create mode 100644 src/weaver_kernel/kernel/_driver_exec.py create mode 100644 tests/test_driver_exec.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fa1d80..c89903f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- **Fail-closed driver execution.** A grouped pass over the invocation path so + the kernel's "controlled, audited execution" promise (I-02) holds on *every* + exit, not just the happy path: + - **Audit + budget release on any driver fault (#152).** `execute_with_fallback` + now treats *any* exception a driver raises — not only `DriverError` — as a + failed attempt, so an unexpected error is captured and surfaced instead of + escaping un-audited with the budget reservation leaked. The post-driver + pipeline in `perform_invoke` (handle creation, firewall transform, token + counting) is likewise wrapped: a fault there releases the reservation + exactly once and records a failure `ActionTrace` before re-raising. + - **Per-invocation deadline (#191).** An optional `invoke_timeout_s` token + constraint bounds each driver attempt (single-shot and streaming, including + a stream inactivity timeout) via `asyncio.wait_for`. A timeout becomes a + synthetic `DriverError`, so the existing fallback and failure-trace paths + apply unchanged. Signed into the token, so the deadline is tamper-evident + and bound to the grant. Defaults to off. + - **Pooled HTTP client + response-size guard (#194).** `HTTPDriver` holds a + single long-lived `httpx.AsyncClient` (connection pooling, configurable + `httpx.Limits`) instead of building one per request, with an `aclose()` for + shutdown. An optional `max_response_bytes` streams and aborts oversized + bodies with a `DriverError` before they are fully buffered. + - **Defensive HTTP body parsing (#197).** A JSON endpoint returning a + non-JSON body now raises a typed `DriverError` (with content-type and a + redaction-safe snippet) instead of leaking `json.JSONDecodeError`. A new + `HTTPEndpoint.response_format="text"` supports text APIs deliberately. - **Context-firewall sizing, budgeting & summary fidelity.** A grouped pass over how the firewall measures, bounds, and represents payloads: - **Allocation-free size estimation (#207).** `firewall.estimated_size` walks diff --git a/docs/integrations.md b/docs/integrations.md index d82a857..bb31b51 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -92,12 +92,18 @@ asyncio.run(main()) ## HTTPDriver -The built-in `HTTPDriver` supports GET, POST, PUT, DELETE: +The built-in `HTTPDriver` supports GET, POST, PUT, DELETE (and any other method +via the generic path): ```python from weaver_kernel.drivers.http import HTTPDriver, HTTPEndpoint -driver = HTTPDriver(driver_id="my_api") +driver = HTTPDriver( + driver_id="my_api", + # Optional response-size guard: reject bodies larger than this before they + # are fully buffered, so an unbounded upstream cannot exhaust memory (#194). + max_response_bytes=5_000_000, +) driver.register_endpoint("users.list", HTTPEndpoint( url="https://api.example.com/users", method="GET", @@ -106,6 +112,36 @@ driver.register_endpoint("users.list", HTTPEndpoint( kernel.register_driver(driver) ``` +The driver holds a single long-lived `httpx.AsyncClient` so requests reuse the +connection pool and keep-alive instead of opening a fresh connection per call +(#194). You own its lifecycle — call `await driver.aclose()` on shutdown (e.g. +in a `finally` block) to release the pool. + +A non-JSON body from a JSON endpoint raises a typed `DriverError` rather than +leaking a decode error (#197). For text APIs, set `response_format="text"` on +the endpoint to receive the decoded body verbatim: + +```python +driver.register_endpoint("status.page", HTTPEndpoint( + url="https://api.example.com/status", + response_format="text", +)) +``` + +### Bounding execution time + +Any driver — HTTP, MCP, or custom — can be bounded by a per-invocation +deadline. Set the `invoke_timeout_s` constraint when the policy issues the +grant; because constraints are signed into the capability token, the deadline +is tamper-evident and travels with the grant. An attempt that exceeds it is +turned into a `DriverError`, so the kernel still records a failure trace and +releases any reserved budget (#191): + +```python +# A policy engine that attaches a 10s deadline to issued tokens: +decision.constraints["invoke_timeout_s"] = 10.0 +``` + ## Custom drivers Any object implementing the `Driver` protocol can be registered: diff --git a/examples/http_driver_demo.py b/examples/http_driver_demo.py index 50d579c..9abded4 100644 --- a/examples/http_driver_demo.py +++ b/examples/http_driver_demo.py @@ -63,6 +63,7 @@ def _start_server(port: int) -> HTTPServer: async def main() -> None: port = 18765 server = _start_server(port) + http_driver = HTTPDriver(driver_id="catalog_api") try: registry = CapabilityRegistry() @@ -76,7 +77,6 @@ async def main() -> None: ) ) - http_driver = HTTPDriver(driver_id="catalog_api") http_driver.register_endpoint( "catalog.list_products", HTTPEndpoint(url=f"http://127.0.0.1:{port}/products", method="GET"), @@ -131,6 +131,8 @@ async def main() -> None: print("\n✓ http_driver_demo.py complete.") finally: + # The driver owns a pooled httpx client; close it on shutdown (#194). + await http_driver.aclose() server.shutdown() diff --git a/src/weaver_kernel/drivers/http.py b/src/weaver_kernel/drivers/http.py index 1365d08..b3e502a 100644 --- a/src/weaver_kernel/drivers/http.py +++ b/src/weaver_kernel/drivers/http.py @@ -2,8 +2,9 @@ from __future__ import annotations +import json from dataclasses import dataclass, field -from typing import Any +from typing import Any, Literal import httpx @@ -11,6 +12,8 @@ from ..models import RawResult from .base import ExecutionContext +_DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) + @dataclass class HTTPEndpoint: @@ -21,14 +24,20 @@ class HTTPEndpoint: headers: dict[str, str] = field(default_factory=dict) timeout: float | None = None """Per-endpoint timeout in seconds. Falls back to the driver's ``default_timeout``.""" + response_format: Literal["json", "text"] = "json" + """How to read a successful body: parse as JSON (default) or keep it as text.""" class HTTPDriver: """A driver that invokes capabilities via HTTP using :mod:`httpx`. - Each operation must be registered with an :class:`HTTPEndpoint`. - The driver performs *synchronous* execution inside an async method by - using ``httpx.AsyncClient`` for proper async support. + Each operation must be registered with an :class:`HTTPEndpoint`. The driver + holds a single long-lived :class:`httpx.AsyncClient` so requests reuse the + connection pool and keep-alive instead of paying a fresh TLS handshake on + every call (#194); call :meth:`aclose` on shutdown to release it. Bodies are + size-bounded (``max_response_bytes``) and parsed defensively — a non-JSON + body from a JSON endpoint raises :class:`DriverError` rather than leaking a + raw decode error (#197). """ def __init__( @@ -37,11 +46,16 @@ def __init__( *, base_headers: dict[str, str] | None = None, default_timeout: float = 30.0, + limits: httpx.Limits | None = None, + max_response_bytes: int | None = None, ) -> None: self._driver_id = driver_id self._endpoints: dict[str, HTTPEndpoint] = {} self._base_headers = base_headers or {} self._default_timeout = default_timeout + self._limits = limits or _DEFAULT_LIMITS + self._max_response_bytes = max_response_bytes + self._client: httpx.AsyncClient | None = None @property def driver_id(self) -> str: @@ -57,6 +71,31 @@ def register_endpoint(self, operation: str, endpoint: HTTPEndpoint) -> None: """ self._endpoints[operation] = endpoint + def _get_client(self) -> httpx.AsyncClient: + """Return the shared client, creating it on first use. + + Built lazily so the connection pool, default headers, and limits are + established once and reused across invocations (#194). + """ + if self._client is None: + self._client = httpx.AsyncClient( + headers=self._base_headers, + timeout=self._default_timeout, + limits=self._limits, + ) + return self._client + + async def aclose(self) -> None: + """Close the shared client and release its connection pool. + + Idempotent — safe to call more than once. Callers that construct an + :class:`HTTPDriver` own its lifecycle and should call this on shutdown + (e.g. in a ``finally`` block or async-context teardown). + """ + if self._client is not None: + await self._client.aclose() + self._client = None + async def execute(self, ctx: ExecutionContext) -> RawResult: """Execute an HTTP request for the given context. @@ -67,10 +106,13 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: ctx: The execution context. Returns: - :class:`RawResult` containing the parsed JSON response. + :class:`RawResult` containing the parsed JSON response, or the raw + text when the endpoint's ``response_format`` is ``"text"``. Raises: - DriverError: If the endpoint is not registered or the request fails. + DriverError: If the endpoint is not registered, the request fails, + the response exceeds ``max_response_bytes``, or a JSON endpoint + returns a body that is not valid JSON. """ operation = str(ctx.args.get("operation", ctx.capability_id)) endpoint = self._endpoints.get(operation) @@ -79,11 +121,10 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: f"HTTPDriver '{self._driver_id}' has no endpoint for operation='{operation}'." ) - headers = {**self._base_headers, **endpoint.headers} + method = endpoint.method.upper() params: dict[str, Any] = {} json_body: dict[str, Any] | None = None - - if endpoint.method.upper() in ("GET", "DELETE"): + if method in ("GET", "DELETE"): params = {k: v for k, v in ctx.args.items() if k != "operation"} else: json_body = {k: v for k, v in ctx.args.items() if k != "operation"} @@ -91,35 +132,103 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: effective_timeout = ( endpoint.timeout if endpoint.timeout is not None else self._default_timeout ) + client = self._get_client() try: - async with httpx.AsyncClient(headers=headers, timeout=effective_timeout) as client: - if endpoint.method.upper() == "GET": - response = await client.get(endpoint.url, params=params) - elif endpoint.method.upper() == "POST": - response = await client.post(endpoint.url, json=json_body) - elif endpoint.method.upper() == "PUT": - response = await client.put(endpoint.url, json=json_body) - elif endpoint.method.upper() == "DELETE": - response = await client.delete(endpoint.url, params=params) - else: - response = await client.request( - endpoint.method.upper(), endpoint.url, json=json_body + async with client.stream( + method, + endpoint.url, + params=params, + json=json_body, + headers=endpoint.headers, + timeout=effective_timeout, + ) as response: + if response.is_error: + await response.aread() + raise DriverError( + f"HTTPDriver '{self._driver_id}': HTTP {response.status_code} " + f"from {endpoint.url}: {response.text[:200]}" ) - response.raise_for_status() - data: Any = response.json() - except httpx.HTTPStatusError as exc: - raise DriverError( - f"HTTPDriver '{self._driver_id}': HTTP {exc.response.status_code} " - f"from {endpoint.url}: {exc.response.text[:200]}" - ) from exc + body = await self._read_bounded(response, url=endpoint.url) + status_code = response.status_code + content_type = response.headers.get("content-type", "") except httpx.RequestError as exc: raise DriverError( f"HTTPDriver '{self._driver_id}': Request to {endpoint.url} failed: {exc}" ) from exc + data = self._decode_body( + body, + response_format=endpoint.response_format, + url=endpoint.url, + content_type=content_type, + ) return RawResult( capability_id=ctx.capability_id, data=data, - metadata={"status_code": response.status_code, "url": endpoint.url}, + metadata={"status_code": status_code, "url": endpoint.url}, ) + + async def _read_bounded(self, response: httpx.Response, *, url: str) -> bytes: + """Read the response body, aborting if it exceeds ``max_response_bytes``. + + Streams chunks so an oversized upstream body is rejected before it is + fully buffered — the firewall's budget only applies *after* a + :class:`RawResult` exists, so the guard has to live here (#194). + + Args: + response: The open streaming response. + url: The request URL, used in the error message. + + Returns: + The full response body as bytes. + + Raises: + DriverError: If the accumulated body exceeds ``max_response_bytes``. + """ + limit = self._max_response_bytes + if limit is None: + return await response.aread() + body = bytearray() + async for chunk in response.aiter_bytes(): + body.extend(chunk) + if len(body) > limit: + raise DriverError( + f"HTTPDriver '{self._driver_id}': response from {url} exceeded " + f"max_response_bytes ({limit})." + ) + return bytes(body) + + def _decode_body( + self, + body: bytes, + *, + response_format: Literal["json", "text"], + url: str, + content_type: str, + ) -> Any: + """Decode a response body per the endpoint's ``response_format``. + + Args: + body: The raw response bytes. + response_format: ``"json"`` to parse, ``"text"`` to decode as a string. + url: The request URL, used in the error message. + content_type: The response ``Content-Type``, used in the error message. + + Returns: + The parsed JSON value (``None`` for an empty body), or the decoded text. + + Raises: + DriverError: If ``response_format`` is ``"json"`` and the body is not + valid JSON (#197). + """ + if response_format == "text": + return body.decode("utf-8", "replace") + try: + return json.loads(body) if body else None + except (json.JSONDecodeError, ValueError) as exc: + snippet = body[:200].decode("utf-8", "replace") + raise DriverError( + f"HTTPDriver '{self._driver_id}': non-JSON response from {url} " + f"(content-type: {content_type or 'unknown'}): {snippet}" + ) from exc diff --git a/src/weaver_kernel/kernel/_driver_exec.py b/src/weaver_kernel/kernel/_driver_exec.py new file mode 100644 index 0000000..b62ee8c --- /dev/null +++ b/src/weaver_kernel/kernel/_driver_exec.py @@ -0,0 +1,151 @@ +"""Driver execution: route-plan fallback, deadlines, and fault capture. + +Extracted from :mod:`._invoke` (AGENTS.md ≤ 300-line budget) so the +driver-execution concern lives in one place: trying a route plan's drivers +in order, bounding each attempt with an optional deadline, and turning +*every* fault into an auditable failed attempt. + +Invariants preserved here (see ``docs/agent-context/invariants.md``): + +* **I-02 (auditability).** Any exception a driver raises — ``DriverError`` or + otherwise — is recorded as a failed attempt and surfaced to + :func:`._invoke.perform_invoke` as ``last_error``, so the caller always + writes a failure :class:`~weaver_kernel.models.ActionTrace` and releases the + budget reservation (#152). A driver fault never escapes un-audited with the + reservation leaked. +* **Bounded execution.** When the token carries an ``invoke_timeout_s`` + constraint, each driver attempt is wrapped in :func:`asyncio.wait_for`; a + timeout becomes a synthetic ``DriverError`` so the existing fallback and + failure-trace paths apply unchanged (#191). +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from ..drivers.base import Driver, ExecutionContext +from ..errors import DriverError +from ..models import RawResult, RoutePlan + +logger = logging.getLogger("weaver_kernel.kernel") + +INVOKE_TIMEOUT_CONSTRAINT = "invoke_timeout_s" +"""Token-constraint key carrying the per-invocation deadline in seconds (#191).""" + + +def resolve_invoke_timeout(constraints: dict[str, Any]) -> float | None: + """Return the per-invocation deadline in seconds, or ``None`` if unset. + + Reads the optional ``"invoke_timeout_s"`` token constraint (#191). Because + constraints are signed into the capability token, the deadline is + tamper-evident and bound to the grant rather than to mutable kernel state. + + Args: + constraints: The verified token's ``constraints`` mapping. + + Returns: + A positive ``float`` deadline, or ``None`` when no deadline applies. + + Raises: + DriverError: If the constraint is present but not a positive number + (booleans are rejected — ``True``/``False`` are not durations). + """ + raw = constraints.get(INVOKE_TIMEOUT_CONSTRAINT) + if raw is None: + return None + if isinstance(raw, bool) or not isinstance(raw, (int, float)) or raw <= 0: + raise DriverError( + f"Invalid '{INVOKE_TIMEOUT_CONSTRAINT}' constraint: {raw!r} " + f"(must be a positive number of seconds)." + ) + return float(raw) + + +async def execute_with_fallback( + drivers: dict[str, Driver], + plan: RoutePlan, + *, + ctx: ExecutionContext, + log_ctx: dict[str, str], + timeout: float | None = None, +) -> tuple[RawResult | None, str, Exception | None, bool]: + """Iterate the route plan's drivers until one succeeds. + + Args: + drivers: The kernel's registered driver map. + plan: The router-resolved route plan to walk in order. + ctx: The execution context handed to each driver. + log_ctx: Structured logging fields propagated to each log record. + timeout: Optional per-attempt deadline in seconds (#191). When set, a + driver attempt exceeding it is converted to a synthetic + ``DriverError`` and treated as a failed attempt. + + Returns: + ``(raw_result, driver_id, last_error, fell_back)``. ``raw_result`` is + ``None`` if every driver failed; ``fell_back`` is ``True`` when at least + one earlier driver raised before the one that ultimately ran (or before + all-failed), so callers can count fallback activations. A route entry + whose driver is unregistered (``drivers.get(driver_id) is None``) is + skipped silently and does **not** set ``fell_back``. + + A ``DriverError`` *and* any other exception a driver raises both count + as a failed attempt (#152); the latter is preserved as ``last_error`` + so the caller still records a failure trace and releases the budget + rather than letting it escape un-audited. + """ + last_error: Exception | None = None + failed_attempts = 0 + for driver_id in plan.driver_ids: + driver = drivers.get(driver_id) + if driver is None: + continue + try: + if timeout is None: + raw_result = await driver.execute(ctx) + else: + raw_result = await asyncio.wait_for(driver.execute(ctx), timeout) + logger.debug("driver_success", extra={**log_ctx, "driver_id": driver_id}) + return raw_result, driver_id, None, failed_attempts > 0 + except asyncio.TimeoutError: + logger.warning( + "driver_timeout", + extra={**log_ctx, "driver_id": driver_id, "timeout_s": timeout}, + ) + last_error = DriverError( + f"Driver '{driver_id}' timed out after {timeout}s " + f"for capability '{ctx.capability_id}'." + ) + failed_attempts += 1 + continue + except DriverError as exc: + logger.warning( + "driver_failure", + extra={**log_ctx, "driver_id": driver_id, "error": str(exc)}, + ) + last_error = exc + failed_attempts += 1 + continue + except Exception as exc: + # I-02: a driver raising an *unexpected* (non-DriverError) exception + # must still be audited. Capture it as a failed attempt so + # perform_invoke records a failure trace and releases the + # reservation (#152) instead of the exception escaping un-traced + # with the budget leaked. + logger.warning( + "driver_failure_unexpected", + extra={ + **log_ctx, + "driver_id": driver_id, + "error_type": type(exc).__name__, + "error": str(exc), + }, + ) + last_error = exc + failed_attempts += 1 + continue + return None, "", last_error, failed_attempts > 0 + + +__all__ = ["execute_with_fallback", "resolve_invoke_timeout", "INVOKE_TIMEOUT_CONSTRAINT"] diff --git a/src/weaver_kernel/kernel/_invoke.py b/src/weaver_kernel/kernel/_invoke.py index 151a3da..0e97bb4 100644 --- a/src/weaver_kernel/kernel/_invoke.py +++ b/src/weaver_kernel/kernel/_invoke.py @@ -21,7 +21,7 @@ from dataclasses import replace from typing import TYPE_CHECKING, Any, cast -from ..drivers.base import Driver, ExecutionContext +from ..drivers.base import ExecutionContext from ..enums import SensitivityTag from ..errors import DriverError from ..firewall.budget_manager import BudgetManager @@ -32,12 +32,12 @@ Frame, Handle, Principal, - RawResult, ResponseMode, RoutePlan, ) from ..stores import TraceStoreProtocol from ..tokens import CapabilityToken +from ._driver_exec import execute_with_fallback, resolve_invoke_timeout if TYPE_CHECKING: # pragma: no cover from . import Kernel @@ -130,45 +130,6 @@ def resolve_effective_mode( return effective -async def execute_with_fallback( - drivers: dict[str, Driver], - plan: RoutePlan, - *, - ctx: ExecutionContext, - log_ctx: dict[str, str], -) -> tuple[RawResult | None, str, Exception | None, bool]: - """Iterate the route plan's drivers until one succeeds. - - Returns: - ``(raw_result, driver_id, last_error, fell_back)``. ``raw_result`` is - ``None`` if every driver failed; ``fell_back`` is ``True`` when at least - one earlier driver raised before the one that ultimately ran (or before - all-failed), so callers can count fallback activations. Only a - ``DriverError`` counts as a failed attempt: a route entry whose driver is - unregistered (``drivers.get(driver_id) is None``) is skipped silently and - does **not** set ``fell_back``. - """ - last_error: Exception | None = None - failed_attempts = 0 - for driver_id in plan.driver_ids: - driver = drivers.get(driver_id) - if driver is None: - continue - try: - raw_result = await driver.execute(ctx) - logger.debug("driver_success", extra={**log_ctx, "driver_id": driver_id}) - return raw_result, driver_id, None, failed_attempts > 0 - except DriverError as exc: - logger.warning( - "driver_failure", - extra={**log_ctx, "driver_id": driver_id, "error": str(exc)}, - ) - last_error = exc - failed_attempts += 1 - continue - return None, "", last_error, failed_attempts > 0 - - def record_failure_trace( *, action_id: str, @@ -260,6 +221,10 @@ async def perform_invoke( the recorded :class:`ActionTrace`. """ action_id = str(uuid.uuid4()) + # Resolve the optional per-invocation deadline (#191) before reserving + # budget so an invalid signed constraint fails fast without leaking a + # reservation. + invoke_timeout = resolve_invoke_timeout(token.constraints) effective_mode = resolve_effective_mode( response_mode=response_mode, principal=principal, @@ -293,7 +258,7 @@ async def perform_invoke( ) downgraded = effective_mode != response_mode raw_result, used_driver_id, last_error, fell_back = await execute_with_fallback( - kernel._driver_map, plan, ctx=ctx, log_ctx=log_ctx + kernel._driver_map, plan, ctx=ctx, log_ctx=log_ctx, timeout=invoke_timeout ) if raw_result is None: @@ -319,18 +284,20 @@ async def perform_invoke( f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}" ) + # I-02: faults *after* the driver returned — handle creation, firewall + # transform, token counting, usage accounting — must not escape un-audited + # or leak the reservation. Capture any escape, release the budget exactly + # once, and record a failure trace before re-raising (#152). handle: Handle | None = None - if effective_mode != "raw": - handle = kernel._handles.store( - capability_id=token.capability_id, - data=raw_result.data, - principal_id=principal.principal_id, - constraints=token.constraints, - ) - kernel._stats.on_handle_store() - - reservation_consumed = False try: + if effective_mode != "raw": + handle = kernel._handles.store( + capability_id=token.capability_id, + data=raw_result.data, + principal_id=principal.principal_id, + constraints=token.constraints, + ) + kernel._stats.on_handle_store() frame = kernel._fw.transform( raw_result, action_id=action_id, @@ -343,10 +310,26 @@ async def perform_invoke( if kernel.budget is not None and reserved_tokens is not None: actual_tokens = kernel.budget.count_tokens(_frame_payload(frame)) await kernel.budget.record_usage(actual_tokens, reserved=reserved_tokens) - reservation_consumed = True - finally: - if not reservation_consumed and kernel.budget is not None and reserved_tokens is not None: + except Exception as exc: + if kernel.budget is not None and reserved_tokens is not None: await kernel.budget.release(reserved_tokens) + err_msg = str(exc) + logger.warning("invoke_failure", extra={**log_ctx, "error": err_msg}) + record_failure_trace( + action_id=action_id, + capability_id=token.capability_id, + principal_id=principal.principal_id, + token_id=token.token_id, + args=args, + response_mode=response_mode, + error_message=err_msg, + trace_store=kernel._traces, + sensitivity=capability.sensitivity, + ) + kernel._stats.on_invocation( + failed=True, fallback=fell_back, redacted=False, downgraded=downgraded + ) + raise record_success_trace( action_id=action_id, @@ -394,7 +377,6 @@ def _frame_payload(frame: Frame) -> Any: __all__ = [ "perform_invoke", "resolve_effective_mode", - "execute_with_fallback", "record_failure_trace", "record_success_trace", ] diff --git a/src/weaver_kernel/kernel/_stream.py b/src/weaver_kernel/kernel/_stream.py index e6835c9..086eb3f 100644 --- a/src/weaver_kernel/kernel/_stream.py +++ b/src/weaver_kernel/kernel/_stream.py @@ -15,6 +15,7 @@ from __future__ import annotations +import asyncio import datetime import logging import uuid @@ -34,6 +35,7 @@ RoutePlan, ) from ..tokens import CapabilityToken +from ._driver_exec import resolve_invoke_timeout from ._invoke import _frame_result_summary, _redact_args_for_trace, resolve_effective_mode if TYPE_CHECKING: # pragma: no cover @@ -54,6 +56,7 @@ async def invoke_stream_impl( ) -> AsyncIterator[Frame]: """Stream Frames for one capability invocation.""" action_id = str(uuid.uuid4()) + invoke_timeout = resolve_invoke_timeout(token.constraints) initial_mode = resolve_effective_mode( response_mode=response_mode, principal=principal, @@ -109,6 +112,7 @@ async def invoke_stream_impl( principal=principal, response_mode=initial_mode, action_id=action_id, + timeout=invoke_timeout, ): yielded_any = True redacted_any = redacted_any or bool(frame.warnings) @@ -119,7 +123,16 @@ async def invoke_stream_impl( fallback_driver = kernel._driver_map.get(fallback_driver_id) if fallback_driver is None: raise DriverError(f"No driver available for capability '{token.capability_id}'.") - raw = await fallback_driver.execute(ctx) + try: + if invoke_timeout is None: + raw = await fallback_driver.execute(ctx) + else: + raw = await asyncio.wait_for(fallback_driver.execute(ctx), invoke_timeout) + except asyncio.TimeoutError as exc: + raise DriverError( + f"Driver '{fallback_driver_id}' timed out after {invoke_timeout}s " + f"for capability '{token.capability_id}'." + ) from exc if initial_mode != "raw": handle = kernel._handles.store( capability_id=token.capability_id, @@ -186,6 +199,7 @@ async def _stream_chunks( principal: Principal, response_mode: ResponseMode, action_id: str, + timeout: float | None = None, ) -> AsyncIterator[Frame]: """Yield firewalled frames for each chunk the driver produces. @@ -197,6 +211,11 @@ async def _stream_chunks( honouring ``apply_stream``'s stateless contract. If the driver ends without an explicit ``__is_final__`` marker, a final sentinel chunk is injected so consumers can detect end-of-stream uniformly. + + When *timeout* is set, it is applied as an *inactivity* deadline between + chunks (#191): a driver that stalls longer than ``timeout`` seconds waiting + for the next chunk is aborted with a ``DriverError`` so a hung stream + cannot freeze the agent loop with a budget reservation held. """ effective_mode = resolve_effective_mode( response_mode=response_mode, @@ -206,7 +225,20 @@ async def _stream_chunks( async def _raw_chunks() -> AsyncIterator[dict[str, Any]]: saw_final = False - async for chunk in driver.execute_stream(ctx): + iterator = aiter(driver.execute_stream(ctx)) + while True: + try: + if timeout is None: + chunk = await anext(iterator) + else: + chunk = await asyncio.wait_for(anext(iterator), timeout) + except StopAsyncIteration: + break + except asyncio.TimeoutError as exc: + raise DriverError( + f"Stream for capability '{token.capability_id}' timed out after " + f"{timeout}s waiting for the next chunk." + ) from exc if chunk.get("__is_final__"): saw_final = True yield chunk diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 5af2c4d..7c55630 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -52,7 +52,7 @@ "policy.py": 652, "kernel/__init__.py": 541, "adapters/_base.py": 459, - "kernel/_invoke.py": 400, + "kernel/_invoke.py": 382, "firewall/transform.py": 377, "adapters/openai.py": 358, "stores/sqlite.py": 350, diff --git a/tests/test_driver_exec.py b/tests/test_driver_exec.py new file mode 100644 index 0000000..cd01ef2 --- /dev/null +++ b/tests/test_driver_exec.py @@ -0,0 +1,152 @@ +"""Tests for kernel/_driver_exec.py: fallback, deadlines, and fault capture.""" + +from __future__ import annotations + +import asyncio + +import pytest + +from weaver_kernel.drivers.base import ExecutionContext +from weaver_kernel.errors import DriverError +from weaver_kernel.kernel._driver_exec import execute_with_fallback, resolve_invoke_timeout +from weaver_kernel.models import RawResult, RoutePlan + + +class _OKDriver: + """A driver that returns a fixed payload.""" + + def __init__(self, driver_id: str = "ok", *, payload: object = "ok") -> None: + self._driver_id = driver_id + self._payload = payload + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + return RawResult(capability_id=ctx.capability_id, data=self._payload) + + +class _RaisingDriver: + """A driver whose ``execute`` raises a caller-supplied exception.""" + + def __init__(self, driver_id: str, exc: Exception) -> None: + self._driver_id = driver_id + self._exc = exc + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + raise self._exc + + +class _SlowDriver: + """A driver that sleeps longer than any reasonable test deadline.""" + + def __init__(self, driver_id: str = "slow") -> None: + self._driver_id = driver_id + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + await asyncio.sleep(1.0) + return RawResult(capability_id=ctx.capability_id, data="late") + + +def _ctx() -> ExecutionContext: + return ExecutionContext(capability_id="cap.x", principal_id="u1") + + +def _plan(*driver_ids: str) -> RoutePlan: + return RoutePlan(capability_id="cap.x", driver_ids=list(driver_ids)) + + +# ── resolve_invoke_timeout ───────────────────────────────────────────────────── + + +def test_resolve_invoke_timeout_absent_is_none() -> None: + assert resolve_invoke_timeout({}) is None + + +def test_resolve_invoke_timeout_valid_returns_float() -> None: + assert resolve_invoke_timeout({"invoke_timeout_s": 5}) == 5.0 + + +@pytest.mark.parametrize("bad", [0, -1, "5", True, [1]]) +def test_resolve_invoke_timeout_invalid_raises(bad: object) -> None: + with pytest.raises(DriverError, match="invoke_timeout_s"): + resolve_invoke_timeout({"invoke_timeout_s": bad}) + + +# ── execute_with_fallback: fault capture (#152) ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_unexpected_exception_is_captured_not_raised() -> None: + """A non-DriverError from a driver becomes ``last_error``, not an escape (#152).""" + boom = ValueError("boom") + drivers = {"bad": _RaisingDriver("bad", boom)} + raw, driver_id, last_error, fell_back = await execute_with_fallback( + drivers, _plan("bad"), ctx=_ctx(), log_ctx={} + ) + assert raw is None + assert driver_id == "" + assert last_error is boom + assert fell_back is True + + +@pytest.mark.asyncio +async def test_fallback_runs_after_unexpected_exception() -> None: + """An unexpected exception is a failed attempt; the next driver still runs (#152).""" + drivers = { + "bad": _RaisingDriver("bad", RuntimeError("kaboom")), + "good": _OKDriver("good", payload={"from": "good"}), + } + raw, driver_id, last_error, fell_back = await execute_with_fallback( + drivers, _plan("bad", "good"), ctx=_ctx(), log_ctx={} + ) + assert raw is not None and raw.data == {"from": "good"} + assert driver_id == "good" + assert last_error is None + assert fell_back is True + + +# ── execute_with_fallback: deadline (#191) ───────────────────────────────────── + + +@pytest.mark.asyncio +async def test_timeout_synthesizes_driver_error() -> None: + drivers = {"slow": _SlowDriver("slow")} + raw, driver_id, last_error, _ = await execute_with_fallback( + drivers, _plan("slow"), ctx=_ctx(), log_ctx={}, timeout=0.01 + ) + assert raw is None + assert isinstance(last_error, DriverError) + assert "timed out after 0.01s" in str(last_error) + + +@pytest.mark.asyncio +async def test_timeout_then_fallback_succeeds() -> None: + drivers = {"slow": _SlowDriver("slow"), "fast": _OKDriver("fast", payload="quick")} + raw, driver_id, last_error, fell_back = await execute_with_fallback( + drivers, _plan("slow", "fast"), ctx=_ctx(), log_ctx={}, timeout=0.01 + ) + assert raw is not None and raw.data == "quick" + assert driver_id == "fast" + assert fell_back is True + + +@pytest.mark.asyncio +async def test_fast_driver_under_deadline_succeeds() -> None: + drivers = {"ok": _OKDriver("ok", payload=42)} + raw, driver_id, last_error, fell_back = await execute_with_fallback( + drivers, _plan("ok"), ctx=_ctx(), log_ctx={}, timeout=5.0 + ) + assert raw is not None and raw.data == 42 + assert driver_id == "ok" + assert last_error is None + assert fell_back is False diff --git a/tests/test_drivers.py b/tests/test_drivers.py index 35ba4f4..e9293d1 100644 --- a/tests/test_drivers.py +++ b/tests/test_drivers.py @@ -116,6 +116,47 @@ async def test_billing_driver_summarize(billing_driver: InMemoryDriver) -> None: # ── HTTPDriver ───────────────────────────────────────────────────────────────── +def _mock_http_client( + *, + body: bytes = b"{}", + status_code: int = 200, + is_error: bool = False, + headers: dict[str, str] | None = None, + request_error: Exception | None = None, +) -> MagicMock: + """Build a mock ``httpx.AsyncClient`` whose ``.stream()`` yields a response. + + Mirrors the streaming contract :class:`HTTPDriver` now relies on: an async + context manager whose response exposes ``is_error``, ``aiter_bytes()``, + ``aread()``, ``status_code``, ``headers``, and ``text``. + """ + response = MagicMock() + response.status_code = status_code + response.is_error = is_error + response.headers = headers or {"content-type": "application/json"} + response.text = body.decode("utf-8", "replace") + response.aread = AsyncMock(return_value=body) + + async def _aiter_bytes() -> Any: + # Emit small chunks so the size guard is exercised across boundaries. + for i in range(0, len(body), 8): + yield body[i : i + 8] + + response.aiter_bytes = _aiter_bytes + + stream_cm = MagicMock() + stream_cm.__aenter__ = AsyncMock(return_value=response) + stream_cm.__aexit__ = AsyncMock(return_value=False) + + client = MagicMock() + if request_error is not None: + client.stream = MagicMock(side_effect=request_error) + else: + client.stream = MagicMock(return_value=stream_cm) + client.aclose = AsyncMock() + return client + + def test_httpdriver_register_endpoint() -> None: driver = HTTPDriver(driver_id="myhttp") endpoint = HTTPEndpoint(url="http://example.com/api", method="GET") @@ -124,29 +165,53 @@ def test_httpdriver_register_endpoint() -> None: @pytest.mark.asyncio -async def test_httpdriver_execute_get(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_httpdriver_execute_get() -> None: driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/test", method="GET") - driver.register_endpoint("get_data", endpoint) - - mock_response = MagicMock() - mock_response.json.return_value = [{"id": 1}] - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() + driver.register_endpoint("get_data", HTTPEndpoint(url="http://localhost:9999/test")) + client = _mock_http_client(body=b'[{"id": 1}]') - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.get = AsyncMock(return_value=mock_response) - - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( - capability_id="cap.x", - principal_id="u1", - args={"operation": "get_data"}, + capability_id="cap.x", principal_id="u1", args={"operation": "get_data"} ) result = await driver.execute(ctx) assert result.data == [{"id": 1}] + assert result.metadata == {"status_code": 200, "url": "http://localhost:9999/test"} + assert client.stream.call_args.args[0] == "GET" + + +@pytest.mark.asyncio +async def test_httpdriver_reuses_pooled_client() -> None: + """The client is built once and reused across invocations (#194).""" + driver = HTTPDriver() + driver.register_endpoint("get_data", HTTPEndpoint(url="http://localhost:9999/test")) + client = _mock_http_client(body=b"{}") + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client) as ctor: + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "get_data"} + ) + await driver.execute(ctx) + await driver.execute(ctx) + assert ctor.call_count == 1 + assert client.stream.call_count == 2 + + +@pytest.mark.asyncio +async def test_httpdriver_aclose_closes_client() -> None: + """aclose() releases the pooled client and is safe to call twice (#194).""" + driver = HTTPDriver() + driver.register_endpoint("get_data", HTTPEndpoint(url="http://localhost:9999/test")) + client = _mock_http_client(body=b"{}") + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "get_data"} + ) + await driver.execute(ctx) + await driver.aclose() + await driver.aclose() # idempotent + client.aclose.assert_awaited_once() @pytest.mark.asyncio @@ -158,130 +223,145 @@ async def test_httpdriver_unknown_operation_raises() -> None: @pytest.mark.asyncio -async def test_httpdriver_http_error_raises(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_httpdriver_http_error_raises() -> None: driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/fail", method="GET") - driver.register_endpoint("fail_op", endpoint) + driver.register_endpoint("fail_op", HTTPEndpoint(url="http://localhost:9999/fail")) + client = _mock_http_client(body=b"Internal Server Error", status_code=500, is_error=True) - mock_response = MagicMock() - mock_response.status_code = 500 - mock_response.text = "Internal Server Error" + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "fail_op"} + ) + with pytest.raises(DriverError, match="HTTP 500"): + await driver.execute(ctx) - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - error = httpx.HTTPStatusError("Server Error", request=MagicMock(), response=mock_response) - mock_client.get = AsyncMock(side_effect=error) +@pytest.mark.asyncio +async def test_httpdriver_non_json_response_raises() -> None: + """A 200 with a non-JSON body surfaces as a typed DriverError (#197).""" + driver = HTTPDriver() + driver.register_endpoint("get_html", HTTPEndpoint(url="http://localhost:9999/page")) + client = _mock_http_client( + body=b"not json", headers={"content-type": "text/html"} + ) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( - capability_id="cap.x", - principal_id="u1", - args={"operation": "fail_op"}, + capability_id="cap.x", principal_id="u1", args={"operation": "get_html"} ) - with pytest.raises(DriverError, match="HTTP 500"): + with pytest.raises(DriverError, match=r"non-JSON response.*content-type: text/html"): await driver.execute(ctx) @pytest.mark.asyncio -async def test_httpdriver_execute_post() -> None: +async def test_httpdriver_text_response_format_returns_string() -> None: + """A ``text`` endpoint returns the decoded body verbatim, no JSON parse (#197).""" driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/items", method="POST") - driver.register_endpoint("create_item", endpoint) + driver.register_endpoint( + "get_text", + HTTPEndpoint(url="http://localhost:9999/page", response_format="text"), + ) + client = _mock_http_client(body=b"plain text body", headers={"content-type": "text/plain"}) - mock_response = MagicMock() - mock_response.json.return_value = {"created": True} - mock_response.status_code = 201 - mock_response.raise_for_status = MagicMock() + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "get_text"} + ) + result = await driver.execute(ctx) + assert result.data == "plain text body" - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.post = AsyncMock(return_value=mock_response) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): +@pytest.mark.asyncio +async def test_httpdriver_empty_body_returns_none() -> None: + driver = HTTPDriver() + driver.register_endpoint("get_empty", HTTPEndpoint(url="http://localhost:9999/empty")) + client = _mock_http_client(body=b"") + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( - capability_id="cap.x", - principal_id="u1", - args={"operation": "create_item", "name": "test"}, + capability_id="cap.x", principal_id="u1", args={"operation": "get_empty"} ) result = await driver.execute(ctx) - assert result.data == {"created": True} - mock_client.post.assert_called_once() + assert result.data is None @pytest.mark.asyncio -async def test_httpdriver_execute_put() -> None: - driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/items/1", method="PUT") - driver.register_endpoint("update_item", endpoint) +async def test_httpdriver_response_size_limit_aborts() -> None: + """A body larger than ``max_response_bytes`` aborts with a DriverError (#194).""" + driver = HTTPDriver(max_response_bytes=10) + driver.register_endpoint("get_big", HTTPEndpoint(url="http://localhost:9999/big")) + client = _mock_http_client(body=b"x" * 100) - mock_response = MagicMock() - mock_response.json.return_value = {"updated": True} - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "get_big"} + ) + with pytest.raises(DriverError, match=r"exceeded max_response_bytes \(10\)"): + await driver.execute(ctx) - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.put = AsyncMock(return_value=mock_response) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): +@pytest.mark.asyncio +async def test_httpdriver_response_size_limit_allows_small_body() -> None: + """A body within ``max_response_bytes`` is returned normally (#194).""" + driver = HTTPDriver(max_response_bytes=1000) + driver.register_endpoint("get_small", HTTPEndpoint(url="http://localhost:9999/small")) + client = _mock_http_client(body=b'{"ok": true}') + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( - capability_id="cap.x", - principal_id="u1", - args={"operation": "update_item", "name": "updated"}, + capability_id="cap.x", principal_id="u1", args={"operation": "get_small"} ) result = await driver.execute(ctx) - assert result.data == {"updated": True} - mock_client.put.assert_called_once() + assert result.data == {"ok": True} @pytest.mark.asyncio -async def test_httpdriver_execute_delete() -> None: +async def test_httpdriver_execute_post_sends_json_body() -> None: driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/items/1", method="DELETE") - driver.register_endpoint("delete_item", endpoint) - - mock_response = MagicMock() - mock_response.json.return_value = {"deleted": True} - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.delete = AsyncMock(return_value=mock_response) + driver.register_endpoint( + "create_item", HTTPEndpoint(url="http://localhost:9999/items", method="POST") + ) + client = _mock_http_client(body=b'{"created": true}', status_code=201) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( capability_id="cap.x", principal_id="u1", - args={"operation": "delete_item", "id": "1"}, + args={"operation": "create_item", "name": "test"}, ) result = await driver.execute(ctx) - assert result.data == {"deleted": True} - mock_client.delete.assert_called_once_with("http://localhost:9999/items/1", params={"id": "1"}) + assert result.data == {"created": True} + assert client.stream.call_args.args[0] == "POST" + assert client.stream.call_args.kwargs["json"] == {"name": "test"} @pytest.mark.asyncio -async def test_httpdriver_execute_patch_uses_request() -> None: +async def test_httpdriver_execute_delete_sends_params() -> None: driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/items/1", method="PATCH") - driver.register_endpoint("patch_item", endpoint) + driver.register_endpoint( + "delete_item", HTTPEndpoint(url="http://localhost:9999/items/1", method="DELETE") + ) + client = _mock_http_client(body=b'{"deleted": true}') - mock_response = MagicMock() - mock_response.json.return_value = {"patched": True} - mock_response.status_code = 200 - mock_response.raise_for_status = MagicMock() + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "delete_item", "id": "1"} + ) + result = await driver.execute(ctx) + assert result.data == {"deleted": True} + assert client.stream.call_args.args[0] == "DELETE" + assert client.stream.call_args.kwargs["params"] == {"id": "1"} - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.request = AsyncMock(return_value=mock_response) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): +@pytest.mark.asyncio +async def test_httpdriver_execute_patch_uses_method() -> None: + driver = HTTPDriver() + driver.register_endpoint( + "patch_item", HTTPEndpoint(url="http://localhost:9999/items/1", method="PATCH") + ) + client = _mock_http_client(body=b'{"patched": true}') + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( capability_id="cap.x", principal_id="u1", @@ -289,29 +369,23 @@ async def test_httpdriver_execute_patch_uses_request() -> None: ) result = await driver.execute(ctx) assert result.data == {"patched": True} - mock_client.request.assert_called_once_with( - "PATCH", "http://localhost:9999/items/1", json={"field": "value"} - ) + assert client.stream.call_args.args[0] == "PATCH" + assert client.stream.call_args.kwargs["json"] == {"field": "value"} @pytest.mark.asyncio async def test_httpdriver_request_error_raises() -> None: driver = HTTPDriver() - endpoint = HTTPEndpoint(url="http://localhost:9999/unreachable", method="GET") - driver.register_endpoint("unreachable_op", endpoint) - - mock_client = AsyncMock() - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_client.get = AsyncMock( - side_effect=httpx.ConnectError("Connection refused", request=MagicMock()) + driver.register_endpoint( + "unreachable_op", HTTPEndpoint(url="http://localhost:9999/unreachable") + ) + client = _mock_http_client( + request_error=httpx.ConnectError("Connection refused", request=MagicMock()) ) - with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=mock_client): + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): ctx = ExecutionContext( - capability_id="cap.x", - principal_id="u1", - args={"operation": "unreachable_op"}, + capability_id="cap.x", principal_id="u1", args={"operation": "unreachable_op"} ) with pytest.raises(DriverError, match="Request to .* failed"): await driver.execute(ctx) diff --git a/tests/test_kernel.py b/tests/test_kernel.py index a5b90e1..7dc55cd 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -2,9 +2,12 @@ from __future__ import annotations +import asyncio + import pytest from weaver_kernel import ( + BudgetManager, Capability, CapabilityRegistry, DriverError, @@ -17,7 +20,10 @@ StaticRouter, TokenExpired, ) -from weaver_kernel.models import CapabilityRequest +from weaver_kernel.drivers.base import ExecutionContext +from weaver_kernel.errors import FirewallError +from weaver_kernel.firewall.transform import Firewall +from weaver_kernel.models import CapabilityRequest, RawResult # ── Full flow: request → grant → invoke → expand → explain ───────────────────── @@ -228,6 +234,153 @@ async def test_all_drivers_fail_raises_driver_error() -> None: await k.invoke(token, principal=principal, args={}) +# ── Fail-closed driver execution: audit + budget release (#152, #191) ────────── + + +class _RawRaisingDriver: + """A driver whose ``execute`` raises a non-DriverError (e.g. a library bug).""" + + def __init__(self, driver_id: str = "raw_raiser") -> None: + self._driver_id = driver_id + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + raise ValueError("unexpected library failure") + + +class _SlowDriver: + """A driver that sleeps past any reasonable per-invocation deadline.""" + + def __init__(self, driver_id: str = "slow") -> None: + self._driver_id = driver_id + + @property + def driver_id(self) -> str: + return self._driver_id + + async def execute(self, ctx: ExecutionContext) -> RawResult: + await asyncio.sleep(1.0) + return RawResult(capability_id=ctx.capability_id, data="late") + + +class _BoomFirewall(Firewall): + """A Firewall whose ``transform`` always fails — to exercise the post-driver path.""" + + def transform(self, *args: object, **kwargs: object): # type: ignore[override] + raise FirewallError("transform boom") + + +def _single_driver_kernel(driver: object, *, budget: int | None = None) -> Kernel: + """Build a kernel routing ``cap`` to *driver*, optionally with a budget.""" + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="cap", + name="Cap", + description="A capability under test", + safety_class=SafetyClass.READ, + ) + ) + return Kernel( + registry=registry, + router=StaticRouter(routes={"cap": [driver.driver_id]}), # type: ignore[attr-defined] + token_provider=HMACTokenProvider(secret="test-secret"), + budget_manager=BudgetManager(total_budget=budget) if budget is not None else None, + ) + + +@pytest.mark.asyncio +async def test_non_driver_error_is_audited_and_budget_released() -> None: + """A driver raising a non-DriverError is audited and the reservation freed (#152).""" + driver = _RawRaisingDriver() + k = _single_driver_kernel(driver, budget=10_000) + k.register_driver(driver) + assert k.budget is not None + before = k.budget.remaining + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1") + with pytest.raises(DriverError, match="All drivers failed"): + await k.invoke(token, principal=principal, args={}) + + # Budget reservation released (no leak), and a failure trace was recorded. + assert k.budget.remaining == before + traces = k.list_traces() + assert len(traces) == 1 + assert traces[0].driver_id == "" + assert "unexpected library failure" in (traces[0].error or "") + + +@pytest.mark.asyncio +async def test_firewall_failure_is_audited_and_budget_released() -> None: + """A post-driver firewall failure is audited and the reservation freed (#152).""" + driver = InMemoryDriver(driver_id="ok") + driver.register_handler("cap", lambda ctx: {"value": 1}) + registry = CapabilityRegistry() + registry.register( + Capability( + capability_id="cap", + name="Cap", + description="A capability under test", + safety_class=SafetyClass.READ, + ) + ) + k = Kernel( + registry=registry, + router=StaticRouter(routes={"cap": ["ok"]}), + token_provider=HMACTokenProvider(secret="test-secret"), + firewall=_BoomFirewall(), + budget_manager=BudgetManager(total_budget=10_000), + ) + k.register_driver(driver) + assert k.budget is not None + before = k.budget.remaining + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1") + with pytest.raises(FirewallError, match="transform boom"): + await k.invoke(token, principal=principal, args={}) + + assert k.budget.remaining == before + traces = k.list_traces() + assert len(traces) == 1 + assert "transform boom" in (traces[0].error or "") + + +@pytest.mark.asyncio +async def test_invoke_timeout_constraint_aborts_slow_driver() -> None: + """A per-grant ``invoke_timeout_s`` constraint bounds execution (#191).""" + driver = _SlowDriver() + k = _single_driver_kernel(driver) + k.register_driver(driver) + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1", constraints={"invoke_timeout_s": 0.01}) + with pytest.raises(DriverError, match="timed out after 0.01s"): + await k.invoke(token, principal=principal, args={}) + + traces = k.list_traces() + assert len(traces) == 1 + assert "timed out" in (traces[0].error or "") + + +@pytest.mark.asyncio +async def test_invalid_invoke_timeout_constraint_rejected() -> None: + """A malformed ``invoke_timeout_s`` constraint fails fast with a typed error (#191).""" + driver = InMemoryDriver(driver_id="ok") + driver.register_handler("cap", lambda ctx: {"value": 1}) + k = _single_driver_kernel(driver) + k.register_driver(driver) + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1", constraints={"invoke_timeout_s": -5}) + with pytest.raises(DriverError, match="invoke_timeout_s"): + await k.invoke(token, principal=principal, args={}) + + # ── Confused-deputy prevention ───────────────────────────────────────────────── From b8555bcc103a5aa854a3cff4e259513026470aa1 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 25 Jun 2026 05:48:22 +0000 Subject: [PATCH 2/3] =?UTF-8?q?fix:=20address=20Copilot=20review=20on=20#2?= =?UTF-8?q?38=20=E2=80=94=20bound=20error=20reads,=20audit=20fidelity,=20c?= =?UTF-8?q?ancellation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolve the 5 review findings on the fail-closed driver-execution PR: - http.py: the error path called `response.aread()` unconditionally, buffering an arbitrarily large error body and bypassing `max_response_bytes`. Read a bounded snippet (≤512 bytes) for the message instead, so the size guard holds on the failure path too (#194). - _invoke.py: failure traces now record the *effective* response mode (matching success and streaming traces) instead of the caller-requested mode, and carry the implicated `driver_id` (the driver that ran and failed downstream, or the last attempted when all failed) — previously hard-coded to "". `record_failure_trace` gained an optional `driver_id`; `execute_with_fallback` now returns the last-attempted driver on failure. - _invoke.py: task cancellation during execution no longer leaks the budget reservation or skips the audit trace. The reservation is released in a single `finally` (covering `CancelledError`, which is not an `Exception`) and a "invocation cancelled" failure trace is recorded before propagating — honouring the PR's "audited on every exit" goal. - _stream.py: the stream trace records the real failure reason (e.g. a timeout `DriverError`), redacted, instead of always None / the generic "stream produced no chunks". Tests: cancellation-mid-invoke (budget released + audited), stream-timeout error-reason trace, bounded error-body read; updated existing assertions for the new driver_id attribution. `make ci` green (792 passed, 94.01%). Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_018VdP3irZSbbPyyoVS1QQbD --- src/weaver_kernel/drivers/http.py | 27 +++++++- src/weaver_kernel/kernel/_driver_exec.py | 19 ++++-- src/weaver_kernel/kernel/_invoke.py | 81 +++++++++++++----------- src/weaver_kernel/kernel/_stream.py | 23 ++++++- tests/test_architecture.py | 2 +- tests/test_driver_exec.py | 2 +- tests/test_drivers.py | 17 +++++ tests/test_kernel.py | 46 +++++++++++++- 8 files changed, 165 insertions(+), 52 deletions(-) diff --git a/src/weaver_kernel/drivers/http.py b/src/weaver_kernel/drivers/http.py index b3e502a..db751b3 100644 --- a/src/weaver_kernel/drivers/http.py +++ b/src/weaver_kernel/drivers/http.py @@ -144,10 +144,12 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: timeout=effective_timeout, ) as response: if response.is_error: - await response.aread() + # Bound the error-body read too: an arbitrarily large error + # body must not be buffered just to build the message (#194). + snippet = await self._read_error_snippet(response) raise DriverError( f"HTTPDriver '{self._driver_id}': HTTP {response.status_code} " - f"from {endpoint.url}: {response.text[:200]}" + f"from {endpoint.url}: {snippet}" ) body = await self._read_bounded(response, url=endpoint.url) status_code = response.status_code @@ -169,6 +171,27 @@ async def execute(self, ctx: ExecutionContext) -> RawResult: metadata={"status_code": status_code, "url": endpoint.url}, ) + async def _read_error_snippet(self, response: httpx.Response, *, max_bytes: int = 512) -> str: + """Read at most ``max_bytes`` of an error body for the failure message. + + Streams and stops early so an oversized error body cannot be buffered in + full — the size guard must hold on the failure path too (#194). Only the + first 200 characters are surfaced in the error message. + + Args: + response: The open streaming response (already known to be an error). + max_bytes: Hard cap on bytes read before giving up. + + Returns: + A decoded, length-bounded snippet of the error body. + """ + chunks = bytearray() + async for chunk in response.aiter_bytes(): + chunks.extend(chunk) + if len(chunks) >= max_bytes: + break + return bytes(chunks).decode("utf-8", "replace")[:200] + async def _read_bounded(self, response: httpx.Response, *, url: str) -> bytes: """Read the response body, aborting if it exceeds ``max_response_bytes``. diff --git a/src/weaver_kernel/kernel/_driver_exec.py b/src/weaver_kernel/kernel/_driver_exec.py index b62ee8c..d8ee14d 100644 --- a/src/weaver_kernel/kernel/_driver_exec.py +++ b/src/weaver_kernel/kernel/_driver_exec.py @@ -83,12 +83,15 @@ async def execute_with_fallback( ``DriverError`` and treated as a failed attempt. Returns: - ``(raw_result, driver_id, last_error, fell_back)``. ``raw_result`` is - ``None`` if every driver failed; ``fell_back`` is ``True`` when at least - one earlier driver raised before the one that ultimately ran (or before - all-failed), so callers can count fallback activations. A route entry - whose driver is unregistered (``drivers.get(driver_id) is None``) is - skipped silently and does **not** set ``fell_back``. + ``(raw_result, driver_id, last_error, fell_back)``. On success + ``driver_id`` is the driver that ran; on failure (``raw_result is + None``) it is the **last driver attempted** (``""`` only if every route + entry was unregistered), so callers can attribute the failure in the + audit trace. ``fell_back`` is ``True`` when at least one earlier driver + raised before the one that ultimately ran (or before all-failed), so + callers can count fallback activations. A route entry whose driver is + unregistered (``drivers.get(driver_id) is None``) is skipped silently + and does **not** set ``fell_back``. A ``DriverError`` *and* any other exception a driver raises both count as a failed attempt (#152); the latter is preserved as ``last_error`` @@ -96,11 +99,13 @@ async def execute_with_fallback( rather than letting it escape un-audited. """ last_error: Exception | None = None + last_driver_id = "" failed_attempts = 0 for driver_id in plan.driver_ids: driver = drivers.get(driver_id) if driver is None: continue + last_driver_id = driver_id try: if timeout is None: raw_result = await driver.execute(ctx) @@ -145,7 +150,7 @@ async def execute_with_fallback( last_error = exc failed_attempts += 1 continue - return None, "", last_error, failed_attempts > 0 + return None, last_driver_id, last_error, failed_attempts > 0 __all__ = ["execute_with_fallback", "resolve_invoke_timeout", "INVOKE_TIMEOUT_CONSTRAINT"] diff --git a/src/weaver_kernel/kernel/_invoke.py b/src/weaver_kernel/kernel/_invoke.py index 0e97bb4..81118fc 100644 --- a/src/weaver_kernel/kernel/_invoke.py +++ b/src/weaver_kernel/kernel/_invoke.py @@ -15,6 +15,7 @@ from __future__ import annotations +import asyncio import datetime import logging import uuid @@ -141,8 +142,15 @@ def record_failure_trace( error_message: str, trace_store: TraceStoreProtocol, sensitivity: SensitivityTag = SensitivityTag.NONE, + driver_id: str = "", ) -> None: - """Persist an :class:`ActionTrace` for a run where no driver succeeded.""" + """Persist an :class:`ActionTrace` for a failed run. + + Args: + driver_id: The driver implicated in the failure — the one that ran and + then failed downstream, or the last one attempted when every driver + failed. Empty only when no driver was reached. + """ trace_store.record( ActionTrace( action_id=action_id, @@ -152,7 +160,7 @@ def record_failure_trace( invoked_at=datetime.datetime.now(tz=datetime.timezone.utc), args=_redact_args_for_trace(capability_id, args), response_mode=response_mode, - driver_id="", + driver_id=driver_id, sensitivity=sensitivity, error=_redact_trace_text(error_message), ) @@ -257,39 +265,46 @@ async def perform_invoke( action_id=action_id, ) downgraded = effective_mode != response_mode - raw_result, used_driver_id, last_error, fell_back = await execute_with_fallback( - kernel._driver_map, plan, ctx=ctx, log_ctx=log_ctx, timeout=invoke_timeout - ) + used_driver_id = "" + fell_back = False - if raw_result is None: - if kernel.budget is not None and reserved_tokens is not None: - await kernel.budget.release(reserved_tokens) - err_msg = str(last_error) if last_error else "No drivers available." - logger.warning("invoke_failure", extra={**log_ctx, "error": err_msg}) + def _record_invoke_failure(message: str, driver_id: str) -> None: + """Log + audit a failed invocation, recording the effective mode (#152).""" + logger.warning("invoke_failure", extra={**log_ctx, "error": message}) record_failure_trace( action_id=action_id, capability_id=token.capability_id, principal_id=principal.principal_id, token_id=token.token_id, args=args, - response_mode=response_mode, - error_message=err_msg, + response_mode=effective_mode, + error_message=message, trace_store=kernel._traces, sensitivity=capability.sensitivity, + driver_id=driver_id, ) kernel._stats.on_invocation( failed=True, fallback=fell_back, redacted=False, downgraded=downgraded ) - raise DriverError( - f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}" - ) - # I-02: faults *after* the driver returned — handle creation, firewall - # transform, token counting, usage accounting — must not escape un-audited - # or leak the reservation. Capture any escape, release the budget exactly - # once, and record a failure trace before re-raising (#152). + # I-02: every exit past the budget reservation — a driver failure, a fault + # in the post-driver pipeline (handle creation, firewall transform, token + # counting), or task cancellation — must release the reservation exactly + # once and leave an audit trace. The reservation is freed in a single + # ``finally`` (which also covers ``CancelledError``, not an ``Exception``); + # each path records its own failure trace before propagating (#152, #191). handle: Handle | None = None + reservation_settled = False try: + raw_result, used_driver_id, last_error, fell_back = await execute_with_fallback( + kernel._driver_map, plan, ctx=ctx, log_ctx=log_ctx, timeout=invoke_timeout + ) + if raw_result is None: + err_msg = str(last_error) if last_error else "No drivers available." + _record_invoke_failure(err_msg, used_driver_id) + raise DriverError( + f"All drivers failed for capability '{token.capability_id}'. Last error: {err_msg}" + ) if effective_mode != "raw": handle = kernel._handles.store( capability_id=token.capability_id, @@ -310,26 +325,18 @@ async def perform_invoke( if kernel.budget is not None and reserved_tokens is not None: actual_tokens = kernel.budget.count_tokens(_frame_payload(frame)) await kernel.budget.record_usage(actual_tokens, reserved=reserved_tokens) + reservation_settled = True # consumed via record_usage (or no budget configured) + except DriverError: + raise # already audited by _record_invoke_failure above + except asyncio.CancelledError: + _record_invoke_failure("invocation cancelled", used_driver_id) + raise except Exception as exc: - if kernel.budget is not None and reserved_tokens is not None: - await kernel.budget.release(reserved_tokens) - err_msg = str(exc) - logger.warning("invoke_failure", extra={**log_ctx, "error": err_msg}) - record_failure_trace( - action_id=action_id, - capability_id=token.capability_id, - principal_id=principal.principal_id, - token_id=token.token_id, - args=args, - response_mode=response_mode, - error_message=err_msg, - trace_store=kernel._traces, - sensitivity=capability.sensitivity, - ) - kernel._stats.on_invocation( - failed=True, fallback=fell_back, redacted=False, downgraded=downgraded - ) + _record_invoke_failure(str(exc), used_driver_id) raise + finally: + if not reservation_settled and kernel.budget is not None and reserved_tokens is not None: + await kernel.budget.release(reserved_tokens) record_success_trace( action_id=action_id, diff --git a/src/weaver_kernel/kernel/_stream.py b/src/weaver_kernel/kernel/_stream.py index 086eb3f..072dab1 100644 --- a/src/weaver_kernel/kernel/_stream.py +++ b/src/weaver_kernel/kernel/_stream.py @@ -36,7 +36,12 @@ ) from ..tokens import CapabilityToken from ._driver_exec import resolve_invoke_timeout -from ._invoke import _frame_result_summary, _redact_args_for_trace, resolve_effective_mode +from ._invoke import ( + _frame_result_summary, + _redact_args_for_trace, + _redact_trace_text, + resolve_effective_mode, +) if TYPE_CHECKING: # pragma: no cover from . import Kernel @@ -102,6 +107,7 @@ async def invoke_stream_impl( redacted_any = False handle: Handle | None = None last_frame: Frame | None = None + stream_error: str | None = None try: if streaming_driver is not None: async for frame in _stream_chunks( @@ -155,7 +161,20 @@ async def invoke_stream_impl( redacted_any = redacted_any or bool(frame.warnings) last_frame = frame yield frame + except Exception as exc: + # Preserve the real failure reason (e.g. a timeout DriverError) so the + # trace records *why* the stream ended, not just that it produced no + # chunks (#191). + stream_error = str(exc) + raise finally: + error: str | None + if stream_error is not None: + error = _redact_trace_text(stream_error) + elif yielded_any: + error = None + else: + error = "stream produced no chunks" kernel._traces.record( ActionTrace( action_id=action_id, @@ -169,7 +188,7 @@ async def invoke_stream_impl( sensitivity=capability.sensitivity, handle_id=handle.handle_id if handle else None, result_summary=(_frame_result_summary(last_frame) if last_frame else None), - error=None if yielded_any else "stream produced no chunks", + error=error, ) ) kernel._stats.on_invocation( diff --git a/tests/test_architecture.py b/tests/test_architecture.py index 7c55630..afd428a 100644 --- a/tests/test_architecture.py +++ b/tests/test_architecture.py @@ -52,7 +52,7 @@ "policy.py": 652, "kernel/__init__.py": 541, "adapters/_base.py": 459, - "kernel/_invoke.py": 382, + "kernel/_invoke.py": 390, "firewall/transform.py": 377, "adapters/openai.py": 358, "stores/sqlite.py": 350, diff --git a/tests/test_driver_exec.py b/tests/test_driver_exec.py index cd01ef2..4cfe5bd 100644 --- a/tests/test_driver_exec.py +++ b/tests/test_driver_exec.py @@ -94,7 +94,7 @@ async def test_unexpected_exception_is_captured_not_raised() -> None: drivers, _plan("bad"), ctx=_ctx(), log_ctx={} ) assert raw is None - assert driver_id == "" + assert driver_id == "bad" # last driver attempted, for audit attribution assert last_error is boom assert fell_back is True diff --git a/tests/test_drivers.py b/tests/test_drivers.py index e9293d1..1a81203 100644 --- a/tests/test_drivers.py +++ b/tests/test_drivers.py @@ -236,6 +236,23 @@ async def test_httpdriver_http_error_raises() -> None: await driver.execute(ctx) +@pytest.mark.asyncio +async def test_httpdriver_error_body_is_bounded() -> None: + """An oversized error body is not buffered in full; the message is bounded (#194).""" + driver = HTTPDriver(max_response_bytes=10) + driver.register_endpoint("fail_big", HTTPEndpoint(url="http://localhost:9999/fail")) + client = _mock_http_client(body=b"E" * 100_000, status_code=500, is_error=True) + + with patch("weaver_kernel.drivers.http.httpx.AsyncClient", return_value=client): + ctx = ExecutionContext( + capability_id="cap.x", principal_id="u1", args={"operation": "fail_big"} + ) + with pytest.raises(DriverError, match="HTTP 500") as excinfo: + await driver.execute(ctx) + # Only a bounded snippet is surfaced, regardless of the 100 KB body. + assert len(str(excinfo.value)) < 300 + + @pytest.mark.asyncio async def test_httpdriver_non_json_response_raises() -> None: """A 200 with a non-JSON body surfaces as a typed DriverError (#197).""" diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 7dc55cd..367763d 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -306,11 +306,12 @@ async def test_non_driver_error_is_audited_and_budget_released() -> None: with pytest.raises(DriverError, match="All drivers failed"): await k.invoke(token, principal=principal, args={}) - # Budget reservation released (no leak), and a failure trace was recorded. + # Budget reservation released (no leak), and a failure trace was recorded + # attributing the last-attempted driver (#152). assert k.budget.remaining == before traces = k.list_traces() assert len(traces) == 1 - assert traces[0].driver_id == "" + assert traces[0].driver_id == "raw_raiser" assert "unexpected library failure" in (traces[0].error or "") @@ -381,6 +382,47 @@ async def test_invalid_invoke_timeout_constraint_rejected() -> None: await k.invoke(token, principal=principal, args={}) +@pytest.mark.asyncio +async def test_cancellation_is_audited_and_budget_released() -> None: + """Cancelling an in-flight invoke releases the reservation and is audited (#152).""" + driver = _SlowDriver() + k = _single_driver_kernel(driver, budget=10_000) + k.register_driver(driver) + assert k.budget is not None + before = k.budget.remaining + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1") + task = asyncio.create_task(k.invoke(token, principal=principal, args={})) + await asyncio.sleep(0.05) # let the task reach the (slow) driver + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert k.budget.remaining == before + traces = k.list_traces() + assert len(traces) == 1 + assert "cancelled" in (traces[0].error or "") + + +@pytest.mark.asyncio +async def test_stream_timeout_records_error_reason() -> None: + """A streaming timeout records the real failure reason, not a generic note (#191).""" + driver = _SlowDriver() + k = _single_driver_kernel(driver) + k.register_driver(driver) + + principal = Principal(principal_id="u1") + token = k._token_provider.issue("cap", "u1", constraints={"invoke_timeout_s": 0.01}) + with pytest.raises(DriverError, match="timed out"): + async for _ in k.invoke_stream(token, principal=principal, args={}): + pass + + traces = k.list_traces() + assert len(traces) == 1 + assert "timed out" in (traces[0].error or "") + + # ── Confused-deputy prevention ───────────────────────────────────────────────── From 057f4107ce4117677292eeb1864829bf027d1c25 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 25 Jun 2026 05:51:00 +0000 Subject: [PATCH 3/3] test: silence CodeQL ineffectual-statement on bare `await task` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CodeQL flagged `await task` (awaiting a cancelled task to re-raise CancelledError) as a statement with no effect — a false positive on a bare await expression statement. Bind the result to `_` so the analyzer sees an effectful assignment; behaviour is unchanged. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_018VdP3irZSbbPyyoVS1QQbD --- tests/test_kernel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_kernel.py b/tests/test_kernel.py index 367763d..c163075 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -397,7 +397,9 @@ async def test_cancellation_is_audited_and_budget_released() -> None: await asyncio.sleep(0.05) # let the task reach the (slow) driver task.cancel() with pytest.raises(asyncio.CancelledError): - await task + # Awaiting the cancelled task re-raises CancelledError — that *is* the + # effect under test (assigned to ``_`` so static analysis sees it). + _ = await task assert k.budget.remaining == before traces = k.list_traces()