From 13ee1c8156ccffd4756b3b5f2ac4795c28b0686a Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 22 Jun 2026 20:10:23 +0000 Subject: [PATCH 1/6] feat(executor): decision-callback audit trail + policy controls (#369, #370) Record guided decision points on StepRecord.decision (DecisionRecord: candidates, chosen, default_tool_name, duration_ms, timed_out) and downgrade flows with decision_candidates steps to DeterminismLevel.PARTIAL (#369). Add opt-in DecisionPolicy (timeout_s, max_decisions_per_flow, on_timeout) enforced via a bounded-join worker thread, with new DecisionTimeoutError (CW-E049) and DecisionBudgetExceededError (CW-E050) (#370). decision_policy=None keeps behavior unchanged. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_011U2ZKg2GMec7iAUDt2zezq --- AGENTS.md | 5 +- CHANGELOG.md | 30 +++ chainweaver/__init__.py | 8 + chainweaver/decisions.py | 96 +++++++- chainweaver/exceptions.py | 52 +++++ chainweaver/executor.py | 312 +++++++++++++++++++------ chainweaver/flow.py | 23 +- tests/fixtures/public_api.json | 41 +++- tests/test_decision_audit_policy.py | 351 ++++++++++++++++++++++++++++ 9 files changed, 832 insertions(+), 86 deletions(-) create mode 100644 tests/test_decision_audit_policy.py diff --git a/AGENTS.md b/AGENTS.md index b6bfcc5..743e63c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -145,7 +145,7 @@ benchmarks/ Standalone benchmark scripts (not coverage-gated): ### Key entry points -- `FlowExecutor(..., decision_callback=...)` → wire a `DecisionCallback` for guided decision points (#102); steps with `decision_candidates` set call the callback to pick which tool to run. Either a class with `decide(ctx)` or a bare callable is accepted (coerced via `coerce_decision_callback`). +- `FlowExecutor(..., decision_callback=...)` → wire a `DecisionCallback` for guided decision points (#102); steps with `decision_candidates` set call the callback to pick which tool to run. Either a class with `decide(ctx)` or a bare callable is accepted (coerced via `coerce_decision_callback`). Each resolution is recorded on `StepRecord.decision` (`DecisionRecord`, #369). Pass `decision_policy=DecisionPolicy(timeout_s=..., max_decisions_per_flow=..., on_timeout=...)` to bound callback latency and per-flow decision count (#370). - `KernelBackedExecutor(..., kernel=...)` from `chainweaver.integrations.agent_kernel` (#89) → optional `FlowExecutor` subclass that delegates `DAGFlowStep` instances with `step_type="capability"` through a `KernelProtocol`. The base `FlowExecutor` rejects capability steps; only this subclass dispatches them. - `flow_to_selectable_item(flow, *, capability_id=None, tags=())` from `chainweaver.integrations.weaver_spec` (#107) → project a `Flow` or `DAGFlow` to a weaver-spec `SelectableItem` for contextweaver catalog ingestion. - `RoutingDecisionAdapter(client=...)` from `chainweaver.integrations.contextweaver` (#106) → `DecisionCallback` impl that asks a `ContextweaverClient` for a `RoutingDecision` and returns the selected capability id. @@ -276,6 +276,9 @@ needing the new state must re-fetch via `get_flow`. `type[BaseModel] | None`. `determinism_level` is a computed `DeterminismLevel` (#8): linear `Flow` → `FULL` (or `NONE` if `deterministic=False`); `DAGFlow` with any conditional `branches` → `PARTIAL`. +Any step (linear or DAG) with non-empty `decision_candidates` (#102) also +downgrades the flow to `PARTIAL` (#369), since a registered callback can pick +a different tool per run. ### `DAGFlowStep` conditional branching (#9) diff --git a/CHANGELOG.md b/CHANGELOG.md index 178a63d..4aa6537 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,36 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Decision-callback audit trail and policy controls** (#369, #370): + guided decision points (#102) are now visible in traces and bounded by + opt-in guardrails. + - *Audit record* (#369): `StepRecord.decision` carries a new + `chainweaver.decisions.DecisionRecord` (`candidates`, `chosen`, + `default_tool_name`, `duration_ms`, `timed_out`), populated **exactly + when** a registered `DecisionCallback` resolves a step. It round-trips + through JSON and is `None` for ordinary steps and for static fallbacks + where no callback ran. + - *Policy controls* (#370): `FlowExecutor(..., decision_policy=DecisionPolicy(...))` + adds a per-decision `timeout_s` (the callback runs on a bounded-join + worker thread; `on_timeout="error"` fails the step with the new + `DecisionTimeoutError` / `CW-E049`, `on_timeout="default"` falls back to + the step's static `tool_name` and records `timed_out=True`) and a + per-flow `max_decisions_per_flow` budget (exceeding it aborts the run + with `DecisionBudgetExceededError` / `CW-E050`). Sub-flows carry their + own independent budget. With `decision_policy=None` (the default), + behavior is unchanged. + +### Changed + +- **Determinism reclassification for decision-bearing flows** (#369, + breaking): a linear `Flow` or `DAGFlow` containing any step with non-empty + `decision_candidates` now reports `DeterminismLevel.PARTIAL` instead of + `FULL` (or `NONE` when `deterministic=False`), matching the existing + `branches` precedent — a registered callback can select different tools on + different runs, so the executed path is data-dependent. Consumers gating on + `FULL` (catalog exporters, governance policies, attestation) will see these + flows reclassified; this is the corrected signal. + - **OpenCode integration** (#276, #277, #278, #279, #280, #282): observe → suggest → compile → expose for OpenCode, end to end and reversible. - *Trace adapter* (#278/#276): `chainweaver.opencode.normalize_opencode_event` diff --git a/chainweaver/__init__.py b/chainweaver/__init__.py index 70c3f44..0a5fce6 100644 --- a/chainweaver/__init__.py +++ b/chainweaver/__init__.py @@ -96,6 +96,8 @@ DecisionCallable, DecisionCallback, DecisionContext, + DecisionPolicy, + DecisionRecord, coerce_decision_callback, ) from chainweaver.decorators import tool @@ -113,7 +115,9 @@ ContribError, CostProfileError, DAGDefinitionError, + DecisionBudgetExceededError, DecisionCallbackError, + DecisionTimeoutError, FlowAlreadyExistsError, FlowAuthenticationError, FlowAuthorizationError, @@ -341,10 +345,14 @@ "DAGDefinitionError", "DAGFlow", "DAGFlowStep", + "DecisionBudgetExceededError", "DecisionCallable", "DecisionCallback", "DecisionCallbackError", "DecisionContext", + "DecisionPolicy", + "DecisionRecord", + "DecisionTimeoutError", "DeterminismLevel", "DraftFlow", "DriftInfo", diff --git a/chainweaver/decisions.py b/chainweaver/decisions.py index ff5b29f..cbdfd23 100644 --- a/chainweaver/decisions.py +++ b/chainweaver/decisions.py @@ -59,9 +59,9 @@ def pick_first(ctx: DecisionContext) -> str: from __future__ import annotations from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable +from typing import Any, Literal, Protocol, runtime_checkable -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator class DecisionContext(BaseModel): @@ -94,6 +94,87 @@ class DecisionContext(BaseModel): context: dict[str, Any] +class DecisionRecord(BaseModel): + """Audit record of a resolved guided decision point (issue #369). + + Populated on a :class:`~chainweaver.executor.StepRecord` **exactly when** + a registered :class:`DecisionCallback` resolved the step — i.e. the step + declared :attr:`~chainweaver.flow.FlowStep.decision_candidates` *and* a + callback was configured. Steps with no candidates, or candidates but no + registered callback (the static-fallback case), carry ``decision=None``. + + Attributes: + candidates: The non-empty candidate tool names the callback chose + from, in declaration order. + chosen: The tool name the callback selected (always one of + ``candidates``). Equals ``default_tool_name`` when the callback + re-selected the static default, or when an ``on_timeout="default"`` + policy fell back after a timeout. + default_tool_name: The step's static ``tool_name`` — what would have + run with no callback registered. + duration_ms: Wall-clock time spent inside the callback, in + milliseconds, measured with :func:`time.perf_counter`. + timed_out: ``True`` when the callback exceeded + :attr:`DecisionPolicy.timeout_s` and an ``on_timeout="default"`` + policy fell back to ``default_tool_name`` (issue #370). ``False`` + for a normal resolution. + """ + + model_config = ConfigDict(frozen=True) + + candidates: list[str] + chosen: str + default_tool_name: str + duration_ms: float + timed_out: bool = False + + +class DecisionPolicy(BaseModel): + """Opt-in guardrails for the decision-callback seam (issue #370). + + Decision callbacks are the one runtime seam where a flow's pace depends + on external code — and, via the contextweaver adapter, potentially on a + remote routing service. A :class:`DecisionPolicy` bounds that seam with + a per-decision timeout and a per-flow decision budget so a slow or + runaway selector cannot stall or dominate a run. + + All controls are opt-in; the executor's default (``decision_policy=None``) + leaves decision behavior byte-for-byte unchanged. + + Attributes: + timeout_s: Maximum wall-clock seconds a single ``decide`` call may + take. The callback runs on a bounded-join worker thread (so the + executor itself adds no network/LLM behavior); a call that + overruns is handled per :attr:`on_timeout`. ``None`` (the + default) imposes no per-decision timeout. + max_decisions_per_flow: Ceiling on how many decision callbacks a + single flow execution may invoke. Exceeding it aborts the flow + with :class:`~chainweaver.exceptions.DecisionBudgetExceededError`. + ``None`` (the default) imposes no budget. Sub-flows run in their + own execution scope and so carry their own independent budget. + on_timeout: What to do when ``timeout_s`` is exceeded. ``"error"`` + (the correctness-first default) fails the step with + :class:`~chainweaver.exceptions.DecisionTimeoutError`; ``"default"`` + falls back to the step's static ``tool_name`` and records the + fallback on the :class:`DecisionRecord` (``timed_out=True``). + """ + + model_config = ConfigDict(frozen=True) + + timeout_s: float | None = None + max_decisions_per_flow: int | None = None + on_timeout: Literal["error", "default"] = "error" + + @model_validator(mode="after") + def _check_bounds(self) -> DecisionPolicy: + """Reject non-positive timeouts and budgets at construction time.""" + if self.timeout_s is not None and self.timeout_s <= 0: + raise ValueError("DecisionPolicy.timeout_s must be positive when set.") + if self.max_decisions_per_flow is not None and self.max_decisions_per_flow < 1: + raise ValueError("DecisionPolicy.max_decisions_per_flow must be >= 1 when set.") + return self + + @runtime_checkable class DecisionCallback(Protocol): """Structural protocol for runtime tool-narrowing callbacks (issue #102). @@ -191,3 +272,14 @@ def coerce_decision_callback( f"decision_callback must implement DecisionCallback or be callable; " f"got {type(cb).__name__}." ) + + +__all__ = [ + "BaseDecisionCallback", + "DecisionCallable", + "DecisionCallback", + "DecisionContext", + "DecisionPolicy", + "DecisionRecord", + "coerce_decision_callback", +] diff --git a/chainweaver/exceptions.py b/chainweaver/exceptions.py index 02e75ed..09ef65a 100644 --- a/chainweaver/exceptions.py +++ b/chainweaver/exceptions.py @@ -743,6 +743,56 @@ def __init__(self, tool_name: str, step_index: int, detail: str) -> None: ) +class DecisionTimeoutError(ChainWeaverError): + """Raised when a decision callback overruns its timeout (issue #370). + + Only raised when a :class:`~chainweaver.decisions.DecisionPolicy` with + ``on_timeout="error"`` is active and the callback's ``decide`` call + exceeds :attr:`~chainweaver.decisions.DecisionPolicy.timeout_s`. Under + ``on_timeout="default"`` the executor falls back to the step's static + ``tool_name`` instead of raising. + + The orphaned callback thread cannot be force-killed and may complete in + the background; its late return is discarded. + + Attributes: + tool_name: The step's static ``tool_name`` at the decision point. + step_index: Zero-based position of the step inside the flow. + timeout_s: The configured per-decision timeout, in seconds. + """ + + def __init__(self, tool_name: str, step_index: int, timeout_s: float) -> None: + self.tool_name = tool_name + self.step_index = step_index + self.timeout_s = timeout_s + super().__init__( + f"Decision callback for step {step_index} (default tool '{tool_name}') " + f"exceeded the {timeout_s}s timeout." + ) + + +class DecisionBudgetExceededError(ChainWeaverError): + """Raised when a flow exceeds its decision budget (issue #370). + + Raised when a :class:`~chainweaver.decisions.DecisionPolicy` sets + ``max_decisions_per_flow`` and the running flow attempts more decision + callbacks than that ceiling allows. Unlike a callback failure (which + aborts a single step), exceeding the budget aborts the whole flow run. + + Attributes: + flow_name: Name of the flow that exhausted its decision budget. + budget: The configured ``max_decisions_per_flow`` ceiling. + """ + + def __init__(self, flow_name: str, budget: int) -> None: + self.flow_name = flow_name + self.budget = budget + super().__init__( + f"Flow '{flow_name}' exceeded its decision budget of {budget} " + f"decision callback(s) per execution." + ) + + class KernelInvocationError(ChainWeaverError): """Raised when a :class:`~chainweaver.integrations.agent_kernel.KernelBackedExecutor` cannot dispatch a capability step (issue #89). @@ -985,6 +1035,8 @@ def __init__(self, predicate: str, detail: str) -> None: FlowAuthenticationError: "CW-E045", RateLimitExceededError: "CW-E046", FlowAuthorizationError: "CW-E047", + DecisionTimeoutError: "CW-E049", + DecisionBudgetExceededError: "CW-E050", } for _exc_cls, _exc_code in _ERROR_CODES.items(): diff --git a/chainweaver/executor.py b/chainweaver/executor.py index 8908874..3370237 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -53,17 +53,22 @@ DecisionCallable, DecisionCallback, DecisionContext, + DecisionPolicy, + DecisionRecord, coerce_decision_callback, ) from chainweaver.events import FlowEvent from chainweaver.exceptions import ( ApprovalDeniedError, AsyncLaneUnsupportedError, + ChainWeaverError, CheckpointDriftError, CheckpointerNotConfiguredError, CheckpointNotFoundError, CheckpointVersionError, + DecisionBudgetExceededError, DecisionCallbackError, + DecisionTimeoutError, FlowCancelledError, FlowCompositionError, FlowExecutionError, @@ -181,6 +186,12 @@ def __init__(self) -> None: # for the duration of an ``execute_flow(dynamic_params=...)`` call and # inherited by sub-flow recursion like the dry-run markers. self.dynamic_params: dict[str, Any] = {} + # Decision-budget counter (issue #370): number of decision callbacks + # invoked so far in this flow execution. Compared against + # ``DecisionPolicy.max_decisions_per_flow``. A sub-flow runs in its + # own scope (a fresh ``_run_scope``), so it starts its budget at zero + # rather than inheriting the parent's count. + self.decision_count: int = 0 def copy(self) -> _RunScopedState: """Return a shallow per-scope clone (mutable containers duplicated). @@ -483,6 +494,14 @@ class StepRecord(BaseModel): (issue #356), the :class:`~chainweaver.approvals.ApprovalRecord` describing the decision. ``None`` for steps whose effective contract did not require approval (the common case). + decision: For a step resolved by a guided decision callback (issue + #369), the :class:`~chainweaver.decisions.DecisionRecord` capturing + the candidate set, the chosen tool, the static default, and the + callback latency. Populated **exactly when** a registered + :class:`~chainweaver.decisions.DecisionCallback` resolved the step; + ``None`` for ordinary steps and for ``decision_candidates`` steps + that fell back to the static ``tool_name`` because no callback was + registered. """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -507,6 +526,7 @@ class StepRecord(BaseModel): flow_name: str | None = None sub_result: ExecutionResult | None = None approval: ApprovalRecord | None = None + decision: DecisionRecord | None = None @model_validator(mode="after") def _fill_error_code(self) -> StepRecord: @@ -716,6 +736,7 @@ def __init__( checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, + decision_policy: DecisionPolicy | None = None, approval_callback: ApprovalCallback | ApprovalCallable | None = None, strict_safety: bool = False, max_side_effect_level: SideEffectLevel | None = None, @@ -760,6 +781,10 @@ def __init__( self._decision_callback: DecisionCallback | None = coerce_decision_callback( decision_callback ) + # Optional guardrails for the decision-callback seam (issue #370): + # a per-decision timeout and a per-flow decision budget. ``None`` + # (the default) leaves decision behavior byte-for-byte unchanged. + self._decision_policy: DecisionPolicy | None = decision_policy # Execution-time safety enforcement (issue #356). All opt-in: # ``approval_callback`` is the seam invoked before a step whose effective # ``ToolSafetyContract`` has ``requires_approval=True``; ``strict_safety`` @@ -1031,6 +1056,7 @@ def with_replaced_tools(self, tools: Iterable[Tool]) -> FlowExecutor: checkpointer=self._checkpointer, delete_on_success=self._delete_on_success, decision_callback=self._decision_callback, + decision_policy=self._decision_policy, max_composition_depth=self._max_composition_depth, max_step_concurrency=self._max_step_concurrency, ) @@ -3579,6 +3605,193 @@ def _execute_subflow_step( ) return record + def _call_decision_with_timeout( + self, + callback: DecisionCallback, + ctx: DecisionContext, + timeout_s: float, + ) -> tuple[str, bool]: + """Invoke ``callback.decide(ctx)`` with a bounded wall-clock timeout (#370). + + The callback runs on a daemon worker thread joined for at most + *timeout_s* seconds — mirroring the threading approach + :meth:`stream_flow` already uses — so the executor itself never adds + network or LLM behavior and the determinism invariants hold. + + Returns a ``(chosen, timed_out)`` pair. When the join times out, + ``timed_out`` is ``True`` and ``chosen`` is an empty placeholder the + caller ignores; the orphaned thread cannot be force-killed and its + late return is discarded. A callback exception is re-raised in the + calling thread so the normal :class:`DecisionCallbackError` path + handles it. + """ + result_holder: list[str] = [] + exc_holder: list[BaseException] = [] + + def _runner() -> None: + try: + result_holder.append(callback.decide(ctx)) + except BaseException as exc: # surfaced in the calling thread below + exc_holder.append(exc) + + thread = threading.Thread( + target=_runner, + name=f"chainweaver-decision-{ctx.flow_name}", + daemon=True, + ) + thread.start() + thread.join(timeout_s) + if thread.is_alive(): + return "", True + if exc_holder: + raise exc_holder[0] + return result_holder[0], False + + def _resolve_decision( + self, + step: FlowStep, + *, + step_index: int, + context: dict[str, Any], + flow_name: str, + trace_id: str, + step_id: str | None, + started_at: datetime, + perf_start: float, + ) -> tuple[FlowStep, DecisionRecord | None, StepRecord | None]: + """Resolve a guided decision point (#102), recording it for audit (#369). + + Returns ``(resolved_step, decision_record, failure_record)``: + + * ``resolved_step`` is *step* rebound to the chosen tool (or *step* + unchanged when no callback ran or the default was re-selected). + * ``decision_record`` is the :class:`DecisionRecord` to stamp on the + step's trace — populated **only** when a registered callback + actually resolved the step; ``None`` otherwise. + * ``failure_record`` is a terminal failed :class:`StepRecord` (already + emitted to ``on_step_end``) when the callback raised, returned an + out-of-set name, or timed out under ``on_timeout="error"``; the + caller returns it immediately. ``None`` on success. + + Raises: + DecisionBudgetExceededError: When a :class:`DecisionPolicy` budget + is configured and this decision would exceed it — aborting the + whole flow rather than a single step (issue #370). + """ + callback = self._decision_callback + if step.decision_candidates is None or callback is None: + return step, None, None + + default_name = step.display_name + candidates = list(step.decision_candidates) + policy = self._decision_policy + + # Per-flow decision budget (issue #370): count every decision attempt + # in this execution scope and abort the flow once the ceiling is passed. + state = self._local + state.decision_count += 1 + if ( + policy is not None + and policy.max_decisions_per_flow is not None + and state.decision_count > policy.max_decisions_per_flow + ): + raise DecisionBudgetExceededError(flow_name, policy.max_decisions_per_flow) + + def _failed(err: ChainWeaverError) -> StepRecord: + log_step_error(_logger, step_index, default_name, err) + err_type, err_msg = _exc_to_strings(err) + record = StepRecord( + step_index=step_index, + tool_name=default_name, + inputs={}, + outputs=None, + error_type=err_type, + error_message=err_msg, + success=False, + started_at=started_at, + ended_at=_now_utc(), + duration_ms=(time.perf_counter() - perf_start) * 1000.0, + ) + self._fire_step_end( + StepEndContext(trace_id=trace_id, flow_name=flow_name, step_record=record) + ) + return record + + ctx = DecisionContext( + trace_id=trace_id, + flow_name=flow_name, + step_index=step_index, + step_id=step_id, + default_tool_name=default_name, + candidates=candidates, + context=dict(context), + ) + + decide_t0 = time.perf_counter() + timed_out = False + try: + if policy is not None and policy.timeout_s is not None: + chosen, timed_out = self._call_decision_with_timeout( + callback, ctx, policy.timeout_s + ) + else: + chosen = callback.decide(ctx) + except Exception as exc: + err = DecisionCallbackError( + default_name, + step_index, + f"callback raised {type(exc).__name__}: {exc}", + ) + err.__cause__ = exc + return step, None, _failed(err) + decide_ms = (time.perf_counter() - decide_t0) * 1000.0 + + if timed_out: + # ``policy`` is non-None here (timeout only fires when set). + assert policy is not None and policy.timeout_s is not None + if policy.on_timeout == "error": + return ( + step, + None, + _failed(DecisionTimeoutError(default_name, step_index, policy.timeout_s)), + ) + # ``on_timeout="default"``: route around the (possibly still-running) + # callback to the step's static tool and mark the fallback. + decision = DecisionRecord( + candidates=candidates, + chosen=default_name, + default_tool_name=default_name, + duration_ms=decide_ms, + timed_out=True, + ) + return step, decision, None + + if chosen not in step.decision_candidates: + return ( + step, + None, + _failed( + DecisionCallbackError( + default_name, + step_index, + f"callback returned '{chosen}' which is not in " + f"decision_candidates={candidates!r}", + ) + ), + ) + + decision = DecisionRecord( + candidates=candidates, + chosen=chosen, + default_tool_name=default_name, + duration_ms=decide_ms, + timed_out=False, + ) + resolved = ( + step if chosen == default_name else step.model_copy(update={"tool_name": chosen}) + ) + return resolved, decision, None + def _execute_step( self, step_index: int, @@ -3635,82 +3848,28 @@ def _execute_step( deadline=deadline, cancel_token=cancel_token, ) - # Resolve guided decision points (issue #102). When the step - # declares ``decision_candidates`` *and* the executor has a - # ``decision_callback`` registered, ask the callback which - # candidate to invoke and rebind the step to that tool for the - # remainder of this call. No callback registered → fall back to - # the static ``tool_name`` so flows stay runnable without the - # integration. Callback failures fail the step early via - # ``DecisionCallbackError`` — silent fall-through would mask - # configuration bugs. - if step.decision_candidates is not None and self._decision_callback is not None: - try: - chosen = self._decision_callback.decide( - DecisionContext( - trace_id=trace_id, - flow_name=flow_name, - step_index=step_index, - step_id=step_id, - default_tool_name=step.display_name, - candidates=list(step.decision_candidates), - context=dict(context), - ) - ) - except Exception as exc: - err = DecisionCallbackError( - step.display_name, - step_index, - f"callback raised {type(exc).__name__}: {exc}", - ) - err.__cause__ = exc - log_step_error(_logger, step_index, step.display_name, err) - err_type, err_msg = _exc_to_strings(err) - now = _now_utc() - record = StepRecord( - step_index=step_index, - tool_name=step.display_name, - inputs={}, - outputs=None, - error_type=err_type, - error_message=err_msg, - success=False, - started_at=started_at, - ended_at=now, - duration_ms=(time.perf_counter() - t0) * 1000.0, - ) - self._fire_step_end( - StepEndContext(trace_id=trace_id, flow_name=flow_name, step_record=record) - ) - return record - if chosen not in step.decision_candidates: - err = DecisionCallbackError( - step.display_name, - step_index, - f"callback returned '{chosen}' which is not in " - f"decision_candidates={list(step.decision_candidates)!r}", - ) - log_step_error(_logger, step_index, step.display_name, err) - err_type, err_msg = _exc_to_strings(err) - now = _now_utc() - record = StepRecord( - step_index=step_index, - tool_name=step.display_name, - inputs={}, - outputs=None, - error_type=err_type, - error_message=err_msg, - success=False, - started_at=started_at, - ended_at=now, - duration_ms=(time.perf_counter() - t0) * 1000.0, - ) - self._fire_step_end( - StepEndContext(trace_id=trace_id, flow_name=flow_name, step_record=record) - ) - return record - if chosen != step.display_name: - step = step.model_copy(update={"tool_name": chosen}) + # Resolve guided decision points (issue #102). When the step declares + # ``decision_candidates`` *and* the executor has a ``decision_callback`` + # registered, ask the callback which candidate to invoke and rebind the + # step to that tool for the remainder of this call. No callback + # registered → fall back to the static ``tool_name`` so flows stay + # runnable without the integration. ``_resolve_decision`` also records + # the choice for the audit trail (issue #369) and enforces any + # ``DecisionPolicy`` timeout / budget (issue #370); callback failures + # fail the step early via a returned record — silent fall-through would + # mask configuration bugs. + step, decision_record, decision_failure = self._resolve_decision( + step, + step_index=step_index, + context=context, + flow_name=flow_name, + trace_id=trace_id, + step_id=step_id, + started_at=started_at, + perf_start=t0, + ) + if decision_failure is not None: + return decision_failure # Mutable holder so ``_invoke_tool`` can report how many times the # primary tool was actually called. Threading this through (instead # of deriving from ``len(retry_errors)``) keeps ``retry_count`` @@ -3758,6 +3917,7 @@ def _record( fallback_used=fallback_used, fallback_tool_name=fallback_tool_name, approval=approval_record, + decision=decision_record, ) def _finish(record: StepRecord) -> StepRecord: diff --git a/chainweaver/flow.py b/chainweaver/flow.py index 69ed595..96511c1 100644 --- a/chainweaver/flow.py +++ b/chainweaver/flow.py @@ -657,8 +657,14 @@ def determinism_level(self) -> DeterminismLevel: Linear :class:`Flow` instances are :class:`DeterminismLevel.FULL` by definition — every step always runs, in declared order — *unless* - the flow author explicitly opts out by setting ``deterministic=False``, - in which case the level is :class:`DeterminismLevel.NONE`. + the flow author explicitly opts out by setting ``deterministic=False`` + (which yields :class:`DeterminismLevel.NONE`), or a step carries a + non-empty :attr:`FlowStep.decision_candidates` list, which downgrades + the flow to :class:`DeterminismLevel.PARTIAL` (issue #369). Guided + decision points (#102) let a registered callback pick which candidate + tool runs, so the executed path is data-dependent at runtime even + though the step sequence is fixed — the same reason :class:`DAGFlow` + downgrades for ``branches``. This property reflects flow *structure* only. Tool-level safety contracts are not consulted here because the flow does not have @@ -669,6 +675,8 @@ def determinism_level(self) -> DeterminismLevel: """ if not self.deterministic: return DeterminismLevel.NONE + if any(step.decision_candidates for step in self.steps): + return DeterminismLevel.PARTIAL return DeterminismLevel.FULL @property @@ -1009,13 +1017,16 @@ def determinism_level(self) -> DeterminismLevel: :class:`DeterminismLevel.PARTIAL` whenever **any** step carries a non-empty :attr:`DAGFlowStep.branches` list — branches make the executed path data-dependent at runtime, even though the graph - itself is fixed. A DAG with no branches is - :class:`DeterminismLevel.FULL`, and any flow that explicitly opts - out via ``deterministic=False`` is :class:`DeterminismLevel.NONE`. + itself is fixed — or a non-empty + :attr:`DAGFlowStep.decision_candidates` list, where a registered + decision callback picks the tool at runtime (issue #369). A DAG with + neither is :class:`DeterminismLevel.FULL`, and any flow that + explicitly opts out via ``deterministic=False`` is + :class:`DeterminismLevel.NONE`. """ if not self.deterministic: return DeterminismLevel.NONE - if any(step.branches for step in self.steps): + if any(step.branches or step.decision_candidates for step in self.steps): return DeterminismLevel.PARTIAL return DeterminismLevel.FULL diff --git a/tests/fixtures/public_api.json b/tests/fixtures/public_api.json index 180b28f..b56565c 100644 --- a/tests/fixtures/public_api.json +++ b/tests/fixtures/public_api.json @@ -41,10 +41,14 @@ "DAGDefinitionError", "DAGFlow", "DAGFlowStep", + "DecisionBudgetExceededError", "DecisionCallable", "DecisionCallback", "DecisionCallbackError", "DecisionContext", + "DecisionPolicy", + "DecisionRecord", + "DecisionTimeoutError", "DeterminismLevel", "DraftFlow", "DriftInfo", @@ -581,6 +585,12 @@ "module": "chainweaver.flow", "qualname": "DAGFlowStep" }, + "DecisionBudgetExceededError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "DecisionBudgetExceededError", + "signature": "(flow_name: str, budget: int) -> None" + }, "DecisionCallable": { "kind": "_CallableGenericAlias", "module": "collections.abc", @@ -612,6 +622,34 @@ "module": "chainweaver.decisions", "qualname": "DecisionContext" }, + "DecisionPolicy": { + "kind": "pydantic-model", + "model_fields": { + "max_decisions_per_flow": "int | NoneType", + "on_timeout": "Literal['error', 'default']", + "timeout_s": "float | NoneType" + }, + "module": "chainweaver.decisions", + "qualname": "DecisionPolicy" + }, + "DecisionRecord": { + "kind": "pydantic-model", + "model_fields": { + "candidates": "list[str]", + "chosen": "str", + "default_tool_name": "str", + "duration_ms": "float", + "timed_out": "bool" + }, + "module": "chainweaver.decisions", + "qualname": "DecisionRecord" + }, + "DecisionTimeoutError": { + "kind": "class", + "module": "chainweaver.exceptions", + "qualname": "DecisionTimeoutError", + "signature": "(tool_name: str, step_index: int, timeout_s: float) -> None" + }, "DeterminismLevel": { "kind": "enum", "module": "chainweaver.contracts", @@ -823,7 +861,7 @@ "kind": "class", "module": "chainweaver.executor", "qualname": "FlowExecutor", - "signature": "(registry: FlowRegistry, *, cost_profile: CostProfile | None = None, redaction_policy: RedactionPolicy | None = None, trace_recorder: TraceRecorder | None = None, middleware: list[FlowExecutorMiddleware] | None = None, step_cache: StepCache | None = None, checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, approval_callback: ApprovalCallback | ApprovalCallable | None = None, strict_safety: bool = False, max_side_effect_level: SideEffectLevel | None = None, discover_plugins: bool = False, max_composition_depth: int = 10, max_step_concurrency: int = 1) -> None" + "signature": "(registry: FlowRegistry, *, cost_profile: CostProfile | None = None, redaction_policy: RedactionPolicy | None = None, trace_recorder: TraceRecorder | None = None, middleware: list[FlowExecutorMiddleware] | None = None, step_cache: StepCache | None = None, checkpointer: Checkpointer | None = None, delete_on_success: bool = True, decision_callback: DecisionCallback | DecisionCallable | None = None, decision_policy: DecisionPolicy | None = None, approval_callback: ApprovalCallback | ApprovalCallable | None = None, strict_safety: bool = False, max_side_effect_level: SideEffectLevel | None = None, discover_plugins: bool = False, max_composition_depth: int = 10, max_step_concurrency: int = 1) -> None" }, "FlowExecutorMiddleware": { "kind": "class", @@ -1448,6 +1486,7 @@ "model_fields": { "approval": "chainweaver.approvals.ApprovalRecord | NoneType", "cached": "bool", + "decision": "chainweaver.decisions.DecisionRecord | NoneType", "duration_ms": "float", "ended_at": "datetime.datetime", "error_code": "str | NoneType", diff --git a/tests/test_decision_audit_policy.py b/tests/test_decision_audit_policy.py new file mode 100644 index 0000000..cf3e030 --- /dev/null +++ b/tests/test_decision_audit_policy.py @@ -0,0 +1,351 @@ +"""Tests for decision-callback audit records and policy controls. + +Covers the determinism-level downgrade and ``StepRecord.decision`` audit +trail (issue #369) and the ``DecisionPolicy`` timeout / budget guardrails +(issue #370). +""" + +from __future__ import annotations + +import time + +import pytest +from helpers import NumberInput, ValueInput, ValueOutput, _add_ten_fn, _double_fn + +from chainweaver.contracts import DeterminismLevel +from chainweaver.decisions import ( + DecisionCallable, + DecisionContext, + DecisionPolicy, + DecisionRecord, +) +from chainweaver.exceptions import DecisionBudgetExceededError +from chainweaver.executor import ExecutionResult, FlowExecutor +from chainweaver.flow import DAGFlow, DAGFlowStep, Flow, FlowStep +from chainweaver.registry import FlowRegistry +from chainweaver.tools import Tool + + +def _build_two_tools() -> tuple[Tool, Tool]: + return ( + Tool( + name="double", + description="Doubles the input.", + input_schema=NumberInput, + output_schema=ValueOutput, + fn=_double_fn, + ), + Tool( + name="add_ten", + description="Adds ten to the input.", + input_schema=ValueInput, + output_schema=ValueOutput, + fn=_add_ten_fn, + ), + ) + + +def _one_decision_flow(name: str = "picky") -> Flow: + return Flow( + name=name, + version="0.1.0", + description="Single decision step, default double.", + steps=[ + FlowStep( + tool_name="double", + input_mapping={"number": "number"}, + decision_candidates=["double", "add_ten"], + ), + ], + ) + + +def _executor( + callback: DecisionCallable | None = None, policy: DecisionPolicy | None = None +) -> FlowExecutor: + reg = FlowRegistry() + reg.register_flow(_one_decision_flow()) + ex = FlowExecutor(registry=reg, decision_callback=callback, decision_policy=policy) + for tool in _build_two_tools(): + ex.register_tool(tool) + return ex + + +# -------------------------------------------------------------------------- +# #369 — determinism level matrix +# -------------------------------------------------------------------------- + + +def test_linear_flow_without_candidates_is_full() -> None: + flow = Flow( + name="plain", + version="0.1.0", + description="No decision points.", + steps=[FlowStep(tool_name="double")], + ) + assert flow.determinism_level is DeterminismLevel.FULL + + +def test_linear_flow_with_candidates_is_partial() -> None: + assert _one_decision_flow().determinism_level is DeterminismLevel.PARTIAL + + +def test_linear_flow_with_candidates_and_nondeterministic_is_none() -> None: + flow = Flow( + name="picky", + version="0.1.0", + description="Opted out of determinism.", + deterministic=False, + steps=[FlowStep(tool_name="double", decision_candidates=["double", "add_ten"])], + ) + assert flow.determinism_level is DeterminismLevel.NONE + + +def test_dag_flow_with_candidates_is_partial() -> None: + dag = DAGFlow( + name="dag_picky", + version="0.1.0", + description="DAG with a decision step.", + steps=[ + DAGFlowStep( + tool_name="double", + step_id="pick", + decision_candidates=["double", "add_ten"], + ), + ], + ) + assert dag.determinism_level is DeterminismLevel.PARTIAL + + +def test_dag_flow_without_candidates_is_full() -> None: + dag = DAGFlow( + name="dag_plain", + version="0.1.0", + description="DAG with no decision step.", + steps=[DAGFlowStep(tool_name="double", step_id="only")], + ) + assert dag.determinism_level is DeterminismLevel.FULL + + +# -------------------------------------------------------------------------- +# #369 — StepRecord.decision audit trail +# -------------------------------------------------------------------------- + + +def test_decision_record_populated_on_callback_success() -> None: + ex = _executor(callback=lambda ctx: "double") + result = ex.execute_flow("picky", {"number": 5}) + assert result.success is True + decision = result.execution_log[0].decision + assert isinstance(decision, DecisionRecord) + assert decision.candidates == ["double", "add_ten"] + assert decision.chosen == "double" + assert decision.default_tool_name == "double" + assert decision.timed_out is False + assert decision.duration_ms >= 0.0 + + +def test_decision_record_captures_overridden_choice() -> None: + flow = Flow( + name="picky", + version="0.1.0", + description="Default double, callback picks add_ten.", + steps=[ + FlowStep( + tool_name="double", + input_mapping={"value": "number"}, + decision_candidates=["double", "add_ten"], + ), + ], + ) + reg = FlowRegistry() + reg.register_flow(flow) + ex = FlowExecutor(registry=reg, decision_callback=lambda ctx: "add_ten") + for tool in _build_two_tools(): + ex.register_tool(tool) + result = ex.execute_flow("picky", {"number": 5}) + rec = result.execution_log[0] + assert rec.tool_name == "add_ten" + assert rec.decision is not None + assert rec.decision.chosen == "add_ten" + assert rec.decision.default_tool_name == "double" + + +def test_decision_record_absent_on_static_fallback() -> None: + # decision_candidates set but no callback registered → static fallback, + # no decision recorded (the callback never ran). + ex = _executor(callback=None) + result = ex.execute_flow("picky", {"number": 5}) + assert result.success is True + assert result.execution_log[0].decision is None + + +def test_decision_record_absent_without_candidates() -> None: + flow = Flow( + name="plain", + version="0.1.0", + description="No candidates.", + steps=[FlowStep(tool_name="double", input_mapping={"number": "number"})], + ) + reg = FlowRegistry() + reg.register_flow(flow) + ex = FlowExecutor(registry=reg, decision_callback=lambda ctx: "double") + for tool in _build_two_tools(): + ex.register_tool(tool) + result = ex.execute_flow("plain", {"number": 5}) + assert result.execution_log[0].decision is None + + +def test_decision_record_round_trips_through_json() -> None: + ex = _executor(callback=lambda ctx: "double") + result = ex.execute_flow("picky", {"number": 5}) + restored = ExecutionResult.model_validate_json(result.model_dump_json()) + decision = restored.execution_log[0].decision + assert decision is not None + assert decision.chosen == "double" + assert decision.candidates == ["double", "add_ten"] + + +# -------------------------------------------------------------------------- +# #370 — DecisionPolicy validation +# -------------------------------------------------------------------------- + + +def test_decision_policy_rejects_non_positive_timeout() -> None: + with pytest.raises(ValueError, match="timeout_s must be positive"): + DecisionPolicy(timeout_s=0) + + +def test_decision_policy_rejects_zero_budget() -> None: + with pytest.raises(ValueError, match="max_decisions_per_flow must be >= 1"): + DecisionPolicy(max_decisions_per_flow=0) + + +def test_decision_policy_defaults_to_error_on_timeout() -> None: + assert DecisionPolicy(timeout_s=1.0).on_timeout == "error" + + +# -------------------------------------------------------------------------- +# #370 — timeout behavior +# -------------------------------------------------------------------------- + + +def _slow_callback(ctx: DecisionContext) -> str: + time.sleep(0.3) + return ctx.default_tool_name + + +def test_timeout_error_fails_the_step() -> None: + ex = _executor( + callback=_slow_callback, + policy=DecisionPolicy(timeout_s=0.05, on_timeout="error"), + ) + result = ex.execute_flow("picky", {"number": 5}) + assert result.success is False + rec = result.execution_log[0] + assert rec.error_type == "DecisionTimeoutError" + assert rec.error_code == "CW-E049" + + +def test_timeout_default_falls_back_to_static_tool() -> None: + ex = _executor( + callback=_slow_callback, + policy=DecisionPolicy(timeout_s=0.05, on_timeout="default"), + ) + result = ex.execute_flow("picky", {"number": 5}) + assert result.success is True + rec = result.execution_log[0] + assert rec.tool_name == "double" + assert rec.outputs == {"value": 10} + assert rec.decision is not None + assert rec.decision.timed_out is True + assert rec.decision.chosen == "double" + + +def test_timed_out_callback_does_not_corrupt_later_runs() -> None: + ex = _executor( + callback=_slow_callback, + policy=DecisionPolicy(timeout_s=0.05, on_timeout="default"), + ) + first = ex.execute_flow("picky", {"number": 5}) + assert first.success is True + # A second run on the same executor is unaffected by the orphaned thread. + second = ex.execute_flow("picky", {"number": 7}) + assert second.success is True + assert second.execution_log[0].outputs == {"value": 14} + + +# -------------------------------------------------------------------------- +# #370 — decision budget +# -------------------------------------------------------------------------- + + +def _two_decision_flow() -> Flow: + return Flow( + name="twopick", + version="0.1.0", + description="Two decision steps.", + on_context_collision="overwrite", + steps=[ + FlowStep( + tool_name="double", + input_mapping={"number": "number"}, + decision_candidates=["double", "add_ten"], + ), + FlowStep( + tool_name="double", + input_mapping={"number": "number"}, + decision_candidates=["double", "add_ten"], + ), + ], + ) + + +def _two_decision_executor(policy: DecisionPolicy | None) -> FlowExecutor: + reg = FlowRegistry() + reg.register_flow(_two_decision_flow()) + ex = FlowExecutor( + registry=reg, + decision_callback=lambda ctx: "double", + decision_policy=policy, + ) + for tool in _build_two_tools(): + ex.register_tool(tool) + return ex + + +def test_budget_exhaustion_aborts_the_flow() -> None: + ex = _two_decision_executor(DecisionPolicy(max_decisions_per_flow=1)) + with pytest.raises(DecisionBudgetExceededError) as excinfo: + ex.execute_flow("twopick", {"number": 5}) + assert excinfo.value.flow_name == "twopick" + assert excinfo.value.budget == 1 + + +def test_budget_within_limit_runs_normally() -> None: + ex = _two_decision_executor(DecisionPolicy(max_decisions_per_flow=2)) + result = ex.execute_flow("twopick", {"number": 5}) + assert result.success is True + assert len(result.execution_log) == 2 + + +def test_budget_resets_per_flow_execution() -> None: + # A budget of 2 lets a two-decision flow run, and the counter resets on the + # next execution rather than accumulating across runs. + ex = _two_decision_executor(DecisionPolicy(max_decisions_per_flow=2)) + assert ex.execute_flow("twopick", {"number": 5}).success is True + assert ex.execute_flow("twopick", {"number": 6}).success is True + + +# -------------------------------------------------------------------------- +# #370 — policy-absent regression +# -------------------------------------------------------------------------- + + +def test_no_policy_leaves_behavior_unchanged() -> None: + ex = _executor(callback=lambda ctx: "double", policy=None) + result = ex.execute_flow("picky", {"number": 5}) + assert result.success is True + assert result.execution_log[0].outputs == {"value": 10} + # The audit record is still written (that is #369, independent of policy). + assert result.execution_log[0].decision is not None From 948ea5a2498e230dc58fe59ce1e7e03a0d3f216d Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 22 Jun 2026 20:25:15 +0000 Subject: [PATCH 2/6] =?UTF-8?q?feat(executor):=20async-lane=20parity=20?= =?UTF-8?q?=E2=80=94=20cache,=20checkpoint=20resume,=20sub-flows=20(#388)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit execute_flow_async now consults the step cache (cached=True on hits, skips Tool.run_async), writes crash-resume checkpoints after each successful step / DAG level, and executes composed flow_name sub-flow steps with sub_result and deadline/cancel_token forwarding. Add resume_flow_async(trace_id) mirroring resume_flow (incl. CheckpointDriftError). Narrow the async-lane guard to keep rejecting only branching (#9) and decision callbacks (#102). Sync lane unchanged. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_011U2ZKg2GMec7iAUDt2zezq --- AGENTS.md | 8 +- CHANGELOG.md | 12 + chainweaver/executor.py | 581 +++++++++++++++++++++++++--- docs/reference/error-table.md | 2 + tests/test_composition.py | 15 +- tests/test_executor_async.py | 55 ++- tests/test_executor_async_parity.py | 274 +++++++++++++ 7 files changed, 861 insertions(+), 86 deletions(-) create mode 100644 tests/test_executor_async_parity.py diff --git a/AGENTS.md b/AGENTS.md index 743e63c..cf00c9e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -150,10 +150,10 @@ benchmarks/ Standalone benchmark scripts (not coverage-gated): - `flow_to_selectable_item(flow, *, capability_id=None, tags=())` from `chainweaver.integrations.weaver_spec` (#107) → project a `Flow` or `DAGFlow` to a weaver-spec `SelectableItem` for contextweaver catalog ingestion. - `RoutingDecisionAdapter(client=...)` from `chainweaver.integrations.contextweaver` (#106) → `DecisionCallback` impl that asks a `ContextweaverClient` for a `RoutingDecision` and returns the selected capability id. - `FlowExecutor.execute_flow(flow_name, initial_input, *, version=None, force=False, deadline=None, cancel_token=None)` → `ExecutionResult`. `version` (#201) targets an exact registered flow version (default: latest); the version that ran is recorded on `ExecutionResult.flow_version`. `deadline` (wall-clock `time.time()` seconds) and `cancel_token` (`CancellationToken`, #142) cooperatively cancel **between** steps / DAG levels — never inside a tool — raising `FlowCancelledError` with the partial result. -- `FlowExecutor.execute_flow_async(flow_name, initial_input, *, version=None, force=False, deadline=None, cancel_token=None)` → `Awaitable[ExecutionResult]` (#80); async-native counterpart of `execute_flow`. Dispatches each step through `Tool.run_async` so async-fn tools (e.g. those produced by `chainweaver.mcp.MCPToolAdapter`) execute on the calling loop and sync-fn tools are offloaded to `asyncio.to_thread`. Supports linear and DAG flows with retries, middleware, and on_error policies; honours `version` / `deadline` / `cancel_token`; rejects composed `flow_name` sub-flow steps (#75, sync-only); defers step cache + checkpoint resume to a follow-up. +- `FlowExecutor.execute_flow_async(flow_name, initial_input, *, version=None, force=False, deadline=None, cancel_token=None)` → `Awaitable[ExecutionResult]` (#80); async-native counterpart of `execute_flow`. Dispatches each step through `Tool.run_async` so async-fn tools (e.g. those produced by `chainweaver.mcp.MCPToolAdapter`) execute on the calling loop and sync-fn tools are offloaded to `asyncio.to_thread`. Supports linear and DAG flows with retries, middleware, and on_error policies; honours `version` / `deadline` / `cancel_token`; executes composed `flow_name` sub-flow steps, consults the step cache, and writes checkpoints — resume via `resume_flow_async(trace_id)` (#388). Still rejects conditional branching (#9) and `decision_candidates` (#102). - `FlowExecutor.stream_flow(flow_name, initial_input, *, force=False)` → `Iterator[FlowEvent]` (#134); yields `kind="flow_start"` → (`step_start` → `step_end`)* → `flow_end` events as the flow runs on a worker thread. Cancellation is not supported for the sync variant; the background thread runs to completion. - `FlowExecutor(..., step_cache=...)` → memoize step outputs across runs (#127); keyed by `(tool_name, schema_hash, input_value_hash)`. Cache hits skip `Tool.fn` entirely (including retries and timeout) and surface as `StepRecord.cached=True`. Tools mark themselves `cacheable=False` to always run (side-effects, external state). `replay_flow` always bypasses the cache. -- `FlowExecutor(..., checkpointer=..., delete_on_success=True)` → crash-resume (#128); writes an `ExecutionSnapshot` after every successful linear step or DAG level. `FlowExecutor.resume_flow(trace_id)` validates the snapshot's flow version and tool `schema_hash` values against the current registry — drift raises `CheckpointDriftError` — then continues execution with the original `trace_id`. Snapshots are deleted on terminal success when `delete_on_success=True` (the default); preserved on failure for operator-driven retry. +- `FlowExecutor(..., checkpointer=..., delete_on_success=True)` → crash-resume (#128); writes an `ExecutionSnapshot` after every successful linear step or DAG level. `FlowExecutor.resume_flow(trace_id)` (or `resume_flow_async(trace_id)` for runs started on the async lane, #388) validates the snapshot's flow version and tool `schema_hash` values against the current registry — drift raises `CheckpointDriftError` — then continues execution with the original `trace_id`. Snapshots are deleted on terminal success when `delete_on_success=True` (the default); preserved on failure for operator-driven retry. - `OTelTraceExporter(tracer=...)` from `chainweaver.integrations.opentelemetry` (#126) → emits OpenTelemetry spans as a `FlowExecutorMiddleware`: one parent `chainweaver.flow.{name}` span + one child `chainweaver.tool.{name}` span per `StepRecord`. After-the-fact export of a completed `ExecutionResult` via `export_result_to_otel(result, tracer=...)`. Optional extra: `pip install 'chainweaver[otel]'`. - `MCPToolAdapter(session)` from `chainweaver.mcp` (#70, #150) → wraps each MCP tool advertised by an open `mcp.ClientSession` as a ChainWeaver `Tool`. `await adapter.discover_tools(server_prefix="…")` returns the wrapped tools; pass `include=[…]` to filter. The resulting tools are async-fn and must be run through `execute_flow_async`. Optional extra: `pip install 'chainweaver[mcp]'`. - `FlowServer(executor, *, name="chainweaver", flow_names=None, server_prefix="")` from `chainweaver.mcp` (#72) → mounts registered flows as MCP tools on a FastMCP server. `server.serve(transport="stdio")` blocks; `await server.serve_async(transport=...)` returns to the loop. Synthesises the dispatcher signature from the flow's input schema so MCP clients call `tool(n=5)` directly. Optional extra: `pip install 'chainweaver[mcp]'`. @@ -261,8 +261,8 @@ for features it does not yet honour, rather than diverging silently: | Opt-in DAG-level concurrency (#344) | sequential | ✅ (`max_step_concurrency`) | | Conditional branches / `default_next` (#9) | ✅ | ❌ rejected | | `decision_candidates` (#102) | ✅ | ❌ rejected | -| Composed sub-flow (`flow_name`, #75) | ✅ | ❌ rejected | -| Step cache / checkpoint resume | ✅ | bypassed | +| Composed sub-flow (`flow_name`, #75) | ✅ | ✅ (#388) | +| Step cache / checkpoint resume | ✅ | ✅ (#388; resume via `resume_flow_async`) | ### State transitions (#335) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4aa6537..6517312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Async-lane parity: cache, checkpoint resume, sub-flow composition** + (#388): `execute_flow_async` now consults the step cache (records + `StepRecord.cached=True` on hits and skips `Tool.run_async`), writes + crash-resume checkpoints after each successful step / DAG level, and + executes composed `flow_name` sub-flow steps (with `sub_result` populated + and `deadline` / `cancel_token` forwarded into the sub-flow). A new + `FlowExecutor.resume_flow_async(trace_id)` mirrors `resume_flow` — + including `CheckpointDriftError` on schema drift — on the async lane. The + async lane still rejects conditional branching (#9) and decision callbacks + (#102) up front via `AsyncLaneUnsupportedError`. Sync-lane behavior is + unchanged. + - **Decision-callback audit trail and policy controls** (#369, #370): guided decision points (#102) are now visible in traces and bounded by opt-in guardrails. diff --git a/chainweaver/executor.py b/chainweaver/executor.py index 3370237..4efe8c3 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -1824,15 +1824,14 @@ async def execute_flow_async( Linear and DAG flows are both supported. The async lane preserves middleware, retries, the ``on_error`` policy, and - flow-level input/output validation. It deliberately bypasses - the step cache and crash-resume checkpointer for v0.1 — those - features are async-unaware today and will be folded into the - async lane in a follow-up. + flow-level input/output validation. As of issue #388 it also + consults the step cache (#127), writes crash-resume checkpoints + (resume via :meth:`resume_flow_async`, #128), and executes composed + sub-flow steps (``flow_name``, #75) — at parity with the sync lane. The async lane does **not** yet honour conditional branching - (``branches`` / ``default_next``, #9), decision callbacks - (``decision_candidates``, #102), or composed sub-flow steps - (``flow_name``, #75). Flows declaring those features raise + (``branches`` / ``default_next``, #9) or decision callbacks + (``decision_candidates``, #102). Flows declaring those features raise :class:`AsyncLaneUnsupportedError` up front — listing every unsupported construct — rather than executing with the directives silently dropped (issue #332); use the synchronous @@ -1862,8 +1861,9 @@ async def execute_flow_async( Raises: AsyncLaneUnsupportedError: When the flow uses conditional - branching, decision callbacks, or composed sub-flow steps, - which the async lane does not yet support (issue #332). + branching or decision callbacks, which the async lane does not + yet support (issue #332). Composed sub-flow steps (#75) are + supported as of issue #388. FlowCancelledError: When *deadline* has passed or *cancel_token* is cancelled at a step boundary. """ @@ -1873,6 +1873,10 @@ async def execute_flow_async( raise FlowStatusError(flow_name, flow.status.value) self._assert_async_lane_supported(flow) + # Validate sub-flow composition up front (issue #75 / #388): reject + # cycles, over-deep nesting, and dangling references before any step + # runs — mirroring the synchronous lane. + self._validate_composition(flow) # A scoped copy per call is what makes concurrent ``execute_flow_async`` # tasks on one event-loop thread safe: each task mutates its own @@ -1893,25 +1897,19 @@ async def execute_flow_async( def _assert_async_lane_supported(flow: Any) -> None: """Reject flows using execution features the async lane can't honour. - ``execute_flow_async`` is a v0.1 lane (issue #80). It does not - yet implement the conditional-branching (#9), decision-callback - (#102), or composed sub-flow (#75) semantics the synchronous - :meth:`execute_flow` supports. The async DAG path builds a plain - tool proxy per step, so those directives would be **silently - dropped** — producing a different result than the sync lane for the - same flow. This collects *every* unsupported construct in the flow - and raises a single :class:`AsyncLaneUnsupportedError` **before any - step runs** (issue #332), so callers see the full set of reasons at - once and route such flows through :meth:`execute_flow` until the - async lane gains parity. + The async lane now supports composed sub-flow steps (#75), the step + cache, and checkpoint resume (issue #388), but still does not + implement conditional branching (``branches`` / ``default_next``, #9) + or guided decision callbacks (``decision_candidates``, #102). Flows + declaring those would have the directives **silently dropped** — + producing a different result than the sync lane — so this collects + *every* still-unsupported construct and raises a single + :class:`AsyncLaneUnsupportedError` **before any step runs** (issue + #332). Route such flows through :meth:`execute_flow` until the async + lane gains full parity. """ unsupported: list[str] = [] for idx, step in enumerate(flow.steps): - if getattr(step, "flow_name", None) is not None: - unsupported.append( - f"step {idx} ('{step.flow_name}'): composed sub-flow " - "(flow_name) steps (issue #75)" - ) if getattr(step, "decision_candidates", None): unsupported.append( f"step {idx} ('{step.display_name}'): decision_candidates (issue #102)" @@ -1998,7 +1996,31 @@ async def _execute_linear_flow_async( perf_start=flow_t0, initial_input=initial_input, ) - record = await self._execute_step_async(idx, step, context, flow_name, trace_id) + try: + record = await self._execute_step_async( + idx, + step, + context, + flow_name, + trace_id, + deadline=deadline, + cancel_token=cancel_token, + ) + except FlowCancelledError as exc: + # Cancellation fired inside a composed sub-flow; re-anchor it to + # this parent flow so the parent's partial carries its completed + # steps — mirroring the sync lane (issue #388). + self._reraise_subflow_cancellation( + exc, + parent_flow_name=flow_name, + step=step, + flat_step_index=idx, + prior_log=log, + trace_id=trace_id, + started_at=flow_started_at, + perf_start=flow_t0, + initial_input=initial_input, + ) log.append(record) if not record.success: return self._make_result( @@ -2026,6 +2048,19 @@ async def _execute_linear_flow_async( output_mapping=step.output_mapping, ) + # Crash-resume checkpoint (issue #128 / #388) — write after every + # successful step so a fresh process can resume via + # ``resume_flow_async``. + self._save_linear_snapshot( + trace_id=trace_id, + flow=flow, + initial_input=initial_input, + started_at=flow_started_at, + context=context, + log=log, + completed_steps=idx + 1, + ) + if flow.output_schema is not None: validation_record = self._validate_flow_schema( flow_name=flow_name, @@ -2075,8 +2110,19 @@ async def _execute_dag_flow_async( Cancellation (issue #142) is checked between levels. """ self._local.active_flow_version = flow.version - trace_id = _new_trace_id() - flow_started_at = _now_utc() + # Resume support (issue #128 / #388): when ``resume_snapshot`` is set, + # reuse its trace_id / started_at / context / log and skip the + # already-completed DAG levels. + resume = self._local.resume_snapshot + levels = self._compute_dag_levels(flow) + if resume is not None: + trace_id = resume.trace_id + flow_started_at = resume.started_at + start_level = resume.completed_dag_levels + else: + trace_id = _new_trace_id() + flow_started_at = _now_utc() + start_level = 0 flow_t0 = time.perf_counter() self._fire_flow_start( @@ -2090,7 +2136,7 @@ async def _execute_dag_flow_async( ) ) - if flow.input_schema is not None: + if resume is None and flow.input_schema is not None: validation_record = self._validate_flow_schema( flow_name=flow.name, payload=initial_input, @@ -2111,15 +2157,21 @@ async def _execute_dag_flow_async( tool_step_count=0, ) - # Seed the context from the validated initial input, then layer any - # dynamic params (#316) on top — server-supplied values win over an - # LLM-provided key of the same name and are never in input_schema. - context: dict[str, Any] = {**initial_input, **self._local.dynamic_params} - log: list[StepRecord] = [] - flat_index = 0 - levels = self._compute_dag_levels(flow) + if resume is not None: + context: dict[str, Any] = dict(resume.context) + log: list[StepRecord] = list(resume.execution_log) + # The flat step index continues past every step in the completed + # levels so resumed records keep their declaration-order indices. + flat_index = sum(len(level) for level in levels[:start_level]) + else: + # Seed the context from the validated initial input, then layer any + # dynamic params (#316) on top — server-supplied values win over an + # LLM-provided key of the same name and are never in input_schema. + context = {**initial_input, **self._local.dynamic_params} + log = [] + flat_index = 0 - for level_idx, level_steps in enumerate(levels): + for level_idx, level_steps in enumerate(levels[start_level:], start=start_level): # Cooperative cancellation between topological levels (issue #142). self._check_cancellation( flow_name=flow.name, @@ -2179,7 +2231,14 @@ async def _execute_dag_flow_async( # order in either case; ``context`` is read-only during the level # (outputs are merged only after it completes), so concurrent input # resolution is safe. - records = await self._run_dag_level_async(indexed_steps, context, flow.name, trace_id) + records = await self._run_dag_level_async( + indexed_steps, + context, + flow.name, + trace_id, + deadline=deadline, + cancel_token=cancel_token, + ) # Process results in declaration order: append to the log, abort on # the first failure, and reject sibling key collisions — identical @@ -2252,6 +2311,18 @@ async def _execute_dag_flow_async( logger=_logger, ) + # Crash-resume checkpoint at the level boundary (issue #128 / #388), + # mirroring the sync DAG lane's level-granularity snapshots. + self._save_dag_snapshot( + trace_id=trace_id, + flow=flow, + initial_input=initial_input, + started_at=flow_started_at, + context=context, + log=log, + completed_levels=level_idx + 1, + ) + if flow.output_schema is not None: validation_record = self._validate_flow_schema( flow_name=flow.name, @@ -2290,6 +2361,9 @@ async def _run_dag_level_async( context: dict[str, Any], flow_name: str, trace_id: str, + *, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, ) -> list[StepRecord]: """Execute one DAG level's steps, returning records in declaration order. @@ -2304,14 +2378,32 @@ async def _run_dag_level_async( merged into it by the caller after the level completes — so concurrent execution introduces no shared-state writes on the executor's side. Opted-in tools must themselves be safe to run concurrently. + + The per-step proxy carries the step's ``flow_name`` (#75 / #388), + contracts, retry, and ``on_error`` so composed sub-flows and per-step + policies are honoured on the async lane just like the sync lane; + ``deadline`` / ``cancel_token`` flow into recursive sub-flow runs. """ async def _run_one(step_index: int, step: Any) -> StepRecord: proxy = FlowStep( - tool_name=step.display_name, + tool_name=step.tool_name, + flow_name=step.flow_name, input_mapping=step.input_mapping, + input_contract=step.input_contract, + output_contract=step.output_contract, + retry=step.retry, + on_error=step.on_error, + ) + return await self._execute_step_async( + step_index, + proxy, + context, + flow_name, + trace_id, + deadline=deadline, + cancel_token=cancel_token, ) - return await self._execute_step_async(step_index, proxy, context, flow_name, trace_id) if self._max_step_concurrency <= 1 or len(indexed_steps) <= 1: return [await _run_one(idx, step) for idx, step in indexed_steps] @@ -2333,6 +2425,9 @@ async def _execute_step_async( context: dict[str, Any], flow_name: str, trace_id: str, + *, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, ) -> StepRecord: """Async-native counterpart to :meth:`_execute_step`. @@ -2344,10 +2439,28 @@ async def _execute_step_async( Retries and middleware hooks (which are sync APIs) are applied in the same order as the sync path; ``on_error`` fallback tools are also dispatched via ``run_async`` so MCP - fallbacks compose correctly. + fallbacks compose correctly. Composed sub-flow steps (#75) recurse + through :meth:`_execute_subflow_step_async`, and the step cache (#127) + is consulted before ``Tool.run_async`` — both mirroring the sync lane + (issue #388). ``deadline`` / ``cancel_token`` are forwarded into + recursive sub-flow runs so a composed run shares one budget. """ started_at = _now_utc() t0 = time.perf_counter() + # Composed sub-flow step (issue #75 / #388): recurse into the named + # flow on the async lane instead of invoking a tool. + if step.flow_name is not None: + return await self._execute_subflow_step_async( + step_index, + step, + context, + flow_name, + trace_id, + started_at=started_at, + perf_start=t0, + deadline=deadline, + cancel_token=cancel_token, + ) tool_attempts = [0] # Approval audit record (issue #356); set by the safety gate below. approval_record: ApprovalRecord | None = None @@ -2360,6 +2473,7 @@ def _record( success: bool, skipped: bool, retry_errors: list[str], + cached: bool = False, fallback_used: bool = False, fallback_tool_name: str | None = None, ) -> StepRecord: @@ -2379,7 +2493,7 @@ def _record( retry_count=retry_count, retry_errors=list(retry_errors), skipped=skipped, - cached=False, + cached=cached, fallback_used=fallback_used, fallback_tool_name=fallback_tool_name, approval=approval_record, @@ -2425,6 +2539,30 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + # Step-level input contract (issue #172) — mirrors the sync lane. + if step.input_contract is not None: + input_contract_cls = step.resolved_input_contract + assert input_contract_cls is not None + contract_err = self._check_step_contract( + step=step, + step_index=step_index, + payload=inputs, + contract=input_contract_cls, + context_label="step_input_contract", + ) + if contract_err is not None: + log_step_error(_logger, step_index, step.display_name, contract_err) + return _finish( + _record( + inputs=inputs, + outputs=None, + error=contract_err, + success=False, + skipped=False, + retry_errors=[], + ) + ) + self._fire_step_start( StepStartContext( trace_id=trace_id, @@ -2467,6 +2605,66 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + # Step cache lookup (issue #127 / #388) — mirrors the sync lane. Hash + # the *validated* inputs so equivalent payloads collapse onto one key; + # bypass for non-cacheable tools and during replay. + cache_key: StepCacheKey | None = None + if self._step_cache is not None and tool.cacheable and not self._local.in_replay: + try: + validated = tool.input_schema.model_validate(inputs) + except ValidationError: + validated = None # let the normal path raise + if validated is not None: + cache_key = StepCacheKey( + tool_name=tool.name, + schema_hash=tool.schema_hash, + input_value_hash=compute_input_value_hash(validated), + ) + cached_output = self._step_cache.get(cache_key) + if cached_output is not None: + if step.output_contract is not None: + cached_contract_cls = step.resolved_output_contract + assert cached_contract_cls is not None + cached_contract_err = self._check_step_contract( + step=step, + step_index=step_index, + payload=cached_output, + contract=cached_contract_cls, + context_label="step_output_contract", + ) + if cached_contract_err is not None: + log_step_error( + _logger, step_index, step.display_name, cached_contract_err + ) + return _finish( + _record( + inputs=inputs, + outputs=None, + error=cached_contract_err, + success=False, + skipped=False, + retry_errors=[], + ) + ) + log_step_end( + _logger, + step_index, + step.display_name, + cached_output, + redaction=self._redaction_policy, + ) + return _finish( + _record( + inputs=inputs, + outputs=cached_output, + error=None, + success=True, + skipped=False, + retry_errors=[], + cached=True, + ) + ) + retry_errors: list[str] = [] try: outputs = await self._invoke_tool_async( @@ -2496,6 +2694,36 @@ def _finish(record: StepRecord) -> StepRecord: outputs, redaction=self._redaction_policy, ) + + # Step-level output contract (issue #172) — mirrors the sync lane, + # validated before the cache records anything. + if step.output_contract is not None: + output_contract_cls = step.resolved_output_contract + assert output_contract_cls is not None + out_contract_err = self._check_step_contract( + step=step, + step_index=step_index, + payload=outputs, + contract=output_contract_cls, + context_label="step_output_contract", + ) + if out_contract_err is not None: + log_step_error(_logger, step_index, step.display_name, out_contract_err) + return _finish( + _record( + inputs=inputs, + outputs=None, + error=out_contract_err, + success=False, + skipped=False, + retry_errors=retry_errors, + ) + ) + + # Cache write after the output has been schema-validated by + # ``Tool.run_async`` — never store invalid output (issue #388). + if cache_key is not None and self._step_cache is not None: + self._step_cache.set(cache_key, outputs) return _finish( _record( inputs=inputs, @@ -2507,6 +2735,104 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + async def _execute_subflow_step_async( + self, + step_index: int, + step: FlowStep, + context: dict[str, Any], + flow_name: str, + trace_id: str, + *, + started_at: datetime, + perf_start: float, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, + ) -> StepRecord: + """Execute a composed sub-flow step on the async lane (#75 / #388). + + Async-native counterpart to :meth:`_execute_subflow_step`: resolves the + step's inputs, recursively ``await``s the named sub-flow via + :meth:`execute_flow_async`, and folds the sub-flow's final output back + into the parent context. ``deadline`` / ``cancel_token`` are forwarded + so flow-level cancellation and the wall-clock budget are observed + *between* the sub-flow's own steps, sharing one budget across the + composed run. Cycles and over-deep nesting are already rejected by + :meth:`_validate_composition` before any step runs. + """ + sub_name = step.flow_name + assert sub_name is not None # guaranteed by the caller / FlowStep validator + try: + sub_input = self._resolve_inputs(step, context, step_index) + except InputMappingError as exc: + log_step_error(_logger, step_index, sub_name, exc) + err_type, err_msg = _exc_to_strings(exc) + now = _now_utc() + return StepRecord( + step_index=step_index, + tool_name=sub_name, + flow_name=sub_name, + inputs={}, + outputs=None, + error_type=err_type, + error_message=err_msg, + success=False, + started_at=started_at, + ended_at=now, + duration_ms=(time.perf_counter() - perf_start) * 1000.0, + ) + + self._fire_step_start( + StepStartContext( + trace_id=trace_id, + flow_name=flow_name, + step_index=step_index, + tool_name=sub_name, + inputs=dict(sub_input), + started_at=started_at, + ) + ) + + saved_version = self._local.active_flow_version + try: + sub_result = await self.execute_flow_async( + sub_name, + sub_input, + deadline=deadline, + cancel_token=cancel_token, + ) + finally: + self._local.active_flow_version = saved_version + + ended_at = _now_utc() + duration_ms = (time.perf_counter() - perf_start) * 1000.0 + error_type: str | None = None + error_message: str | None = None + if not sub_result.success: + error_type = "FlowExecutionError" + error_message = f"Sub-flow '{sub_name}' failed." + record = StepRecord( + step_index=step_index, + tool_name=sub_name, + flow_name=sub_name, + inputs=sub_input, + outputs=sub_result.final_output if sub_result.success else None, + error_type=error_type, + error_message=error_message, + success=sub_result.success, + started_at=started_at, + ended_at=ended_at, + duration_ms=duration_ms, + sub_result=sub_result, + ) + self._fire_step_end( + StepEndContext( + trace_id=trace_id, + flow_name=flow_name, + step_record=record, + ) + ) + return record + async def _invoke_tool_async( self, tool: Tool, @@ -3090,6 +3416,18 @@ def resume_flow(self, trace_id: str) -> ExecutionResult: has an incompatible MAJOR component relative to the version this library writes (issue #395). """ + flow, snapshot = self._load_snapshot_for_resume(trace_id) + if isinstance(flow, DAGFlow): + return self._resume_dag_flow(flow, snapshot) + return self._resume_linear_flow(flow, snapshot) + + def _load_snapshot_for_resume(self, trace_id: str) -> tuple[AnyFlow, ExecutionSnapshot]: + """Load and drift-check the snapshot for *trace_id* (issue #128). + + Shared by :meth:`resume_flow` and :meth:`resume_flow_async` so the + checkpointer-configured / not-found / version / drift guards stay + identical across the sync and async resume lanes. + """ if self._checkpointer is None: raise CheckpointerNotConfiguredError() snapshot = self._checkpointer.load(trace_id) @@ -3127,10 +3465,42 @@ def resume_flow(self, trace_id: str) -> ExecutionResult: f"tool '{tool_name}' schema_hash changed: " f"snapshot='{snap_hash}' current='{current.schema_hash}'", ) + return flow, snapshot - if isinstance(flow, DAGFlow): - return self._resume_dag_flow(flow, snapshot) - return self._resume_linear_flow(flow, snapshot) + async def resume_flow_async(self, trace_id: str) -> ExecutionResult: + """Asynchronously resume an in-flight execution from a snapshot (#388). + + Coroutine counterpart to :meth:`resume_flow`: applies the identical + checkpointer / version / drift guards (via + :meth:`_load_snapshot_for_resume`) and then continues execution on the + async lane — ``await``-ing each remaining step (linear) or DAG level so + async-fn tools run natively. Use this to resume a run originally + executed via :meth:`execute_flow_async`. + + Args: + trace_id: Trace id of the snapshot to resume. + + Returns: + An :class:`ExecutionResult` for the (now-completed) flow. + + Raises: + CheckpointerNotConfiguredError: When no checkpointer was configured. + CheckpointNotFoundError: When no snapshot exists for *trace_id*. + FlowNotFoundError: When the snapshot's flow is no longer registered. + CheckpointDriftError: When the flow version or any tool's + ``schema_hash`` changed since the snapshot was written. + CheckpointVersionError: When the snapshot's ``snapshot_version`` is + incompatible with this library's (issue #395). + AsyncLaneUnsupportedError: When the snapshot's flow uses features + the async lane does not support (branching / decision + callbacks). + """ + flow, snapshot = self._load_snapshot_for_resume(trace_id) + self._assert_async_lane_supported(flow) + with self._run_scope(): + if isinstance(flow, DAGFlow): + return await self._resume_dag_flow_async(flow, snapshot) + return await self._resume_linear_flow_async(flow, snapshot) def _resume_linear_flow( self, @@ -3277,6 +3647,127 @@ def _resume_dag_flow( finally: self._local.resume_snapshot = None + async def _resume_linear_flow_async( + self, + flow: Any, + snapshot: ExecutionSnapshot, + ) -> ExecutionResult: + """Async counterpart to :meth:`_resume_linear_flow` (issue #388). + + Continues a linear execution from ``snapshot.completed_steps``, + ``await``-ing each remaining step so async-fn tools run natively. + """ + self._local.active_flow_version = flow.version + trace_id = snapshot.trace_id + flow_name = snapshot.flow_name + flow_started_at = snapshot.started_at + flow_t0 = time.perf_counter() + _logger.info( + "Flow '%s' (async) resuming | trace_id=%s | from_step=%d", + flow_name, + trace_id, + snapshot.completed_steps, + ) + self._fire_flow_start( + FlowStartContext( + trace_id=trace_id, + flow_name=flow_name, + flow_version=flow.version, + initial_input=dict(snapshot.initial_input), + started_at=flow_started_at, + total_steps=len(flow.steps), + ) + ) + + context: dict[str, Any] = dict(snapshot.context) + log: list[StepRecord] = list(snapshot.execution_log) + + for idx in range(snapshot.completed_steps, len(flow.steps)): + step = flow.steps[idx] + record = await self._execute_step_async(idx, step, context, flow_name, trace_id) + log.append(record) + + if not record.success: + return self._make_result( + flow_name=flow_name, + success=False, + final_output=None, + execution_log=log, + trace_id=trace_id, + started_at=flow_started_at, + perf_start=flow_t0, + initial_input=dict(snapshot.initial_input), + ) + + assert record.outputs is not None # success guarantees outputs + context.update( + apply_output_mapping( + record.outputs, + step.output_mapping, + tool_name=step.display_name, + step_index=idx, + ) + ) + + self._save_linear_snapshot( + trace_id=trace_id, + flow=flow, + initial_input=snapshot.initial_input, + started_at=flow_started_at, + context=context, + log=log, + completed_steps=idx + 1, + ) + + if flow.output_schema is not None: + validation_record = self._validate_flow_schema( + flow_name=flow_name, + payload=context, + schema=flow.output_schema, + step_index=flow_output_step_index(flow), + context_label="flow_output", + ) + if validation_record is not None: + return self._make_result( + flow_name=flow_name, + success=False, + final_output=None, + execution_log=[*log, validation_record], + trace_id=trace_id, + started_at=flow_started_at, + perf_start=flow_t0, + initial_input=dict(snapshot.initial_input), + tool_step_count=len(log), + ) + + return self._make_result( + flow_name=flow_name, + success=True, + final_output=context, + execution_log=log, + trace_id=trace_id, + started_at=flow_started_at, + perf_start=flow_t0, + initial_input=dict(snapshot.initial_input), + ) + + async def _resume_dag_flow_async( + self, + flow: DAGFlow, + snapshot: ExecutionSnapshot, + ) -> ExecutionResult: + """Async counterpart to :meth:`_resume_dag_flow` (issue #388). + + Seeds the resume slot that :meth:`_execute_dag_flow_async` consults to + reuse the snapshot's ``trace_id`` / ``started_at`` / context / log and + skip the already-completed DAG levels. + """ + self._local.resume_snapshot = snapshot + try: + return await self._execute_dag_flow_async(flow, dict(snapshot.initial_input)) + finally: + self._local.resume_snapshot = None + def _record_observed_trace(self, result: ExecutionResult) -> None: """Mirror an :class:`ExecutionResult` into the configured TraceRecorder.""" recorder = self._trace_recorder diff --git a/docs/reference/error-table.md b/docs/reference/error-table.md index 120ff5d..6e839dc 100644 --- a/docs/reference/error-table.md +++ b/docs/reference/error-table.md @@ -66,6 +66,8 @@ lacks a code, a code is duplicated, or a code is missing from this table. | `CW-E046` | `RateLimitExceededError` | A `FlowServer` rate limiter declined the call. | | `CW-E047` | `FlowAuthorizationError` | A `FlowServer` authorization callback denied the call (client-safe reason code only). | | `CW-E048` | `OpenCodeAdapterError` | An OpenCode plugin payload could not be normalized into a trace event, or a flow name has no name-safe characters. | +| `CW-E049` | `DecisionTimeoutError` | A decision callback exceeded the `DecisionPolicy.timeout_s` budget while `on_timeout="error"` was in effect. | +| `CW-E050` | `DecisionBudgetExceededError` | A flow exceeded its `DecisionPolicy.max_decisions_per_flow` budget. | ## Catching strategy diff --git a/tests/test_composition.py b/tests/test_composition.py index 7592679..3d156ef 100644 --- a/tests/test_composition.py +++ b/tests/test_composition.py @@ -21,7 +21,6 @@ ) from chainweaver.cost import CostProfile from chainweaver.exceptions import ( - AsyncLaneUnsupportedError, FlowCancelledError, FlowCompositionError, ) @@ -368,7 +367,9 @@ def test_dag_step_can_reference_subflow(self) -> None: assert result.final_output is not None assert result.final_output["b"] == 3 - async def test_async_rejects_subflow_steps(self) -> None: + async def test_async_executes_subflow_steps(self) -> None: + # Composed sub-flow steps now run on the async lane (issue #388), + # with the nested ExecutionResult attached as ``sub_result``. executor = _base_executor() executor._registry.register_flow( Flow( @@ -378,8 +379,14 @@ async def test_async_rejects_subflow_steps(self) -> None: steps=[FlowStep(flow_name="inc", input_mapping={"n": "n"})], ) ) - with pytest.raises(AsyncLaneUnsupportedError, match=r"sub-flow"): - await executor.execute_flow_async("parent_async", {"n": 1}) + result = await executor.execute_flow_async("parent_async", {"n": 1}) + assert result.success is True + assert result.final_output is not None + assert result.final_output["a"] == 2 + record = result.execution_log[0] + assert record.flow_name == "inc" + assert record.sub_result is not None + assert record.sub_result.success is True # --------------------------------------------------------------------------- diff --git a/tests/test_executor_async.py b/tests/test_executor_async.py index 00350c8..57a89db 100644 --- a/tests/test_executor_async.py +++ b/tests/test_executor_async.py @@ -283,57 +283,46 @@ async def test_default_next_rejected(self) -> None: with pytest.raises(AsyncLaneUnsupportedError, match="default_next"): await executor.execute_flow_async("routed", {"n": 1}) - async def test_subflow_step_rejected(self) -> None: - registry = FlowRegistry() - leaf = Flow( - name="leaf", - version="1.0.0", - description="", - steps=[FlowStep(tool_name="async_increment", input_mapping={"n": "n"})], - ) - parent = Flow( - name="parent", - version="1.0.0", - description="", - steps=[FlowStep(flow_name="leaf", input_mapping={"n": "n"})], - ) - registry.register_flow(leaf) - registry.register_flow(parent) - executor = FlowExecutor(registry=registry) - with pytest.raises(AsyncLaneUnsupportedError, match="sub-flow"): - await executor.execute_flow_async("parent", {"n": 1}) - async def test_error_lists_all_unsupported_constructs(self) -> None: + # Two *still*-unsupported families on the async lane (issue #388 added + # sub-flow support but branching and decision callbacks remain): a + # decision_candidates step and a branches step are both reported in one + # error, before any step runs. registry = FlowRegistry() - flow = Flow( + dag = DAGFlow( name="multi", version="1.0.0", description="", steps=[ - FlowStep( + DAGFlowStep( + step_id="a", tool_name="async_increment", input_mapping={"n": "n"}, decision_candidates=["async_increment", "async_double_value"], ), - FlowStep(flow_name="leaf", input_mapping={"n": "n"}), + DAGFlowStep( + step_id="b", + tool_name="async_increment", + input_mapping={"n": "n"}, + depends_on=["a"], + branches=[ConditionalEdge(target_step_id="c", predicate="n > 0")], + ), + DAGFlowStep( + step_id="c", + tool_name="async_increment", + input_mapping={"n": "n"}, + depends_on=["b"], + ), ], ) - leaf = Flow( - name="leaf", - version="1.0.0", - description="", - steps=[FlowStep(tool_name="async_increment", input_mapping={"n": "n"})], - ) - registry.register_flow(leaf) - registry.register_flow(flow) + registry.register_flow(dag) executor = FlowExecutor(registry=registry) with pytest.raises(AsyncLaneUnsupportedError) as exc_info: await executor.execute_flow_async("multi", {"n": 1}) - # Both unsupported constructs are reported in one error, before any step. assert len(exc_info.value.unsupported) == 2 message = str(exc_info.value) assert "decision_candidates" in message - assert "sub-flow" in message + assert "conditional branches" in message class TestExecuteFlowAsyncFallback: diff --git a/tests/test_executor_async_parity.py b/tests/test_executor_async_parity.py new file mode 100644 index 0000000..cf355ee --- /dev/null +++ b/tests/test_executor_async_parity.py @@ -0,0 +1,274 @@ +"""Async-lane parity tests for issue #388. + +Cover the step cache, checkpoint resume (``resume_flow_async``), and composed +sub-flow execution now supported by :meth:`FlowExecutor.execute_flow_async`, +mirroring the sync-lane expectations in ``tests/test_cache.py`` / +``tests/test_checkpoint.py`` / ``tests/test_composition.py``. +""" + +from __future__ import annotations + +import time +from typing import Any + +import pytest +from helpers import NumberInput, ValueInput, ValueOutput, _add_ten_fn, _double_fn +from pydantic import BaseModel + +from chainweaver import ( + Flow, + FlowExecutor, + FlowRegistry, + FlowStep, + InMemoryCheckpointer, + InMemoryStepCache, + Tool, +) +from chainweaver.exceptions import ( + CheckpointDriftError, + CheckpointerNotConfiguredError, + FlowCancelledError, +) + + +class _NIn(BaseModel): + n: int + + +class _Out(BaseModel): + value: int + + +# -------------------------------------------------------------------------- +# Async step cache (#388) +# -------------------------------------------------------------------------- + + +async def test_async_cache_hit_skips_tool_fn() -> None: + calls = {"n": 0} + + async def _counting(inp: _NIn) -> dict[str, Any]: + calls["n"] += 1 + return {"value": inp.n + 1} + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="cached_async", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="inc", input_mapping={"n": "n"})], + ) + ) + cache = InMemoryStepCache() + ex = FlowExecutor(registry=registry, step_cache=cache) + ex.register_tool( + Tool(name="inc", description="", input_schema=_NIn, output_schema=_Out, fn=_counting) + ) + + first = await ex.execute_flow_async("cached_async", {"n": 5}) + assert first.success is True + assert first.execution_log[0].cached is False + assert calls["n"] == 1 + + second = await ex.execute_flow_async("cached_async", {"n": 5}) + assert second.success is True + assert second.execution_log[0].cached is True + # The tool's callable was not invoked a second time — served from cache. + assert calls["n"] == 1 + assert second.final_output is not None + assert second.final_output["value"] == 6 + + +async def test_async_cache_bypassed_for_non_cacheable_tool() -> None: + calls = {"n": 0} + + async def _counting(inp: _NIn) -> dict[str, Any]: + calls["n"] += 1 + return {"value": inp.n + 1} + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="uncached_async", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="inc", input_mapping={"n": "n"})], + ) + ) + ex = FlowExecutor(registry=registry, step_cache=InMemoryStepCache()) + ex.register_tool( + Tool( + name="inc", + description="", + input_schema=_NIn, + output_schema=_Out, + fn=_counting, + cacheable=False, + ) + ) + + await ex.execute_flow_async("uncached_async", {"n": 5}) + second = await ex.execute_flow_async("uncached_async", {"n": 5}) + assert second.execution_log[0].cached is False + assert calls["n"] == 2 + + +# -------------------------------------------------------------------------- +# Async checkpoint resume (#388) +# -------------------------------------------------------------------------- + + +def _async_crash_setup() -> tuple[FlowExecutor, InMemoryCheckpointer]: + """A 2-step flow whose second step always raises, run on the async lane.""" + ck = InMemoryCheckpointer() + + def _explode(_inp: ValueInput) -> dict[str, Any]: + raise RuntimeError("simulated async crash") + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="crash_async", + version="0.1.0", + description="", + steps=[ + FlowStep(tool_name="double", input_mapping={"number": "number"}), + FlowStep(tool_name="bad", input_mapping={"value": "value"}), + ], + ) + ) + ex = FlowExecutor(registry=registry, checkpointer=ck) + ex.register_tool( + Tool( + name="double", + description="", + input_schema=NumberInput, + output_schema=ValueOutput, + fn=_double_fn, + ) + ) + ex.register_tool( + Tool( + name="bad", + description="", + input_schema=ValueInput, + output_schema=ValueOutput, + fn=_explode, + ) + ) + return ex, ck + + +async def test_resume_flow_async_resumes_after_crash() -> None: + ex, ck = _async_crash_setup() + result = await ex.execute_flow_async("crash_async", {"number": 5}) + assert result.success is False + trace_id = result.trace_id + # A snapshot was written after the first (successful) step. + assert ck.load(trace_id) is not None + + # Operator deploys a fix for the failing tool, then resumes. + ex.register_tool( + Tool( + name="bad", + description="", + input_schema=ValueInput, + output_schema=ValueOutput, + fn=_add_ten_fn, + ) + ) + resumed = await ex.resume_flow_async(trace_id) + assert resumed.success is True + assert resumed.trace_id == trace_id + assert len(resumed.execution_log) == 2 + assert resumed.execution_log[0].tool_name == "double" + assert resumed.execution_log[0].outputs == {"value": 10} + assert resumed.execution_log[1].tool_name == "bad" + + +async def test_resume_flow_async_raises_on_schema_drift() -> None: + ex, _ck = _async_crash_setup() + result = await ex.execute_flow_async("crash_async", {"number": 5}) + trace_id = result.trace_id + + # Re-register the already-completed 'double' tool with a different output + # schema so its schema_hash changes — resume must refuse on drift. + class _OtherOut(BaseModel): + value: int + extra: str = "x" + + ex.register_tool( + Tool( + name="double", + description="", + input_schema=NumberInput, + output_schema=_OtherOut, + fn=lambda inp: {"value": inp.number * 2, "extra": "x"}, + ) + ) + with pytest.raises(CheckpointDriftError): + await ex.resume_flow_async(trace_id) + + +async def test_resume_flow_async_without_checkpointer_raises() -> None: + registry = FlowRegistry() + ex = FlowExecutor(registry=registry) + with pytest.raises(CheckpointerNotConfiguredError): + await ex.resume_flow_async("nope") + + +# -------------------------------------------------------------------------- +# Async sub-flow composition: deadline / cancel forwarding (#388) +# -------------------------------------------------------------------------- + + +async def test_async_subflow_deadline_forwarded_into_subflow() -> None: + """A deadline that lands *between* the sub-flow's own steps must fire.""" + + async def _slow(inp: _NIn) -> dict[str, Any]: + time.sleep(0.05) # push past the deadline within the sub-flow + return {"value": inp.n + 1} + + async def _passthrough(inp: _Out) -> dict[str, Any]: + return {"value": inp.value} + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="sub", + version="1.0.0", + description="", + steps=[ + FlowStep(tool_name="slow", input_mapping={"n": "n"}), + FlowStep(tool_name="passthrough", input_mapping={"value": "value"}), + ], + ) + ) + registry.register_flow( + Flow( + name="parent", + version="1.0.0", + description="", + steps=[FlowStep(flow_name="sub", input_mapping={"n": "n"})], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + Tool(name="slow", description="", input_schema=_NIn, output_schema=_Out, fn=_slow) + ) + ex.register_tool( + Tool( + name="passthrough", + description="", + input_schema=_Out, + output_schema=_Out, + fn=_passthrough, + ) + ) + + deadline = time.time() + 0.02 + with pytest.raises(FlowCancelledError) as excinfo: + await ex.execute_flow_async("parent", {"n": 1}, deadline=deadline) + # The cancellation is re-anchored to the parent flow. + assert excinfo.value.flow_name == "parent" From 9d87d711186293ef421652665efd6c20f88e330a Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 22 Jun 2026 20:30:02 +0000 Subject: [PATCH 3/6] feat(executor): stream_flow_async + streamed-run cancellation (#389) Add FlowExecutor.stream_flow_async, an async generator yielding the same FlowEvent lifecycle sequence as stream_flow by driving execute_flow_async on the calling loop (no worker thread). cancel_token/deadline end the stream at the next step boundary (FlowCancelledError carries the partial); abandoning the iterator cancels the backing task. Sync stream_flow gains optional cancel_token/deadline forwarded to the worker. Refactor the stream collector to emit via a callable so both lanes share it. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_011U2ZKg2GMec7iAUDt2zezq --- AGENTS.md | 3 +- CHANGELOG.md | 12 +++ chainweaver/executor.py | 151 +++++++++++++++++++++++--- tests/test_streaming_async.py | 194 ++++++++++++++++++++++++++++++++++ 4 files changed, 343 insertions(+), 17 deletions(-) create mode 100644 tests/test_streaming_async.py diff --git a/AGENTS.md b/AGENTS.md index cf00c9e..cc8b0d5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -151,7 +151,8 @@ benchmarks/ Standalone benchmark scripts (not coverage-gated): - `RoutingDecisionAdapter(client=...)` from `chainweaver.integrations.contextweaver` (#106) → `DecisionCallback` impl that asks a `ContextweaverClient` for a `RoutingDecision` and returns the selected capability id. - `FlowExecutor.execute_flow(flow_name, initial_input, *, version=None, force=False, deadline=None, cancel_token=None)` → `ExecutionResult`. `version` (#201) targets an exact registered flow version (default: latest); the version that ran is recorded on `ExecutionResult.flow_version`. `deadline` (wall-clock `time.time()` seconds) and `cancel_token` (`CancellationToken`, #142) cooperatively cancel **between** steps / DAG levels — never inside a tool — raising `FlowCancelledError` with the partial result. - `FlowExecutor.execute_flow_async(flow_name, initial_input, *, version=None, force=False, deadline=None, cancel_token=None)` → `Awaitable[ExecutionResult]` (#80); async-native counterpart of `execute_flow`. Dispatches each step through `Tool.run_async` so async-fn tools (e.g. those produced by `chainweaver.mcp.MCPToolAdapter`) execute on the calling loop and sync-fn tools are offloaded to `asyncio.to_thread`. Supports linear and DAG flows with retries, middleware, and on_error policies; honours `version` / `deadline` / `cancel_token`; executes composed `flow_name` sub-flow steps, consults the step cache, and writes checkpoints — resume via `resume_flow_async(trace_id)` (#388). Still rejects conditional branching (#9) and `decision_candidates` (#102). -- `FlowExecutor.stream_flow(flow_name, initial_input, *, force=False)` → `Iterator[FlowEvent]` (#134); yields `kind="flow_start"` → (`step_start` → `step_end`)* → `flow_end` events as the flow runs on a worker thread. Cancellation is not supported for the sync variant; the background thread runs to completion. +- `FlowExecutor.stream_flow(flow_name, initial_input, *, force=False, deadline=None, cancel_token=None)` → `Iterator[FlowEvent]` (#134); yields `kind="flow_start"` → (`step_start` → `step_end`)* → `flow_end` events as the flow runs on a worker thread. A `deadline` / `cancel_token` is checked at step boundaries on the worker (#389); abandoning the iterator still lets the in-flight step run to completion. +- `FlowExecutor.stream_flow_async(flow_name, initial_input, *, force=False, deadline=None, cancel_token=None)` → `AsyncIterator[FlowEvent]` (#389); async-native counterpart driving `execute_flow_async` on the calling loop (no worker thread). Same event order; `cancel_token` / `deadline` end the stream promptly at the next step boundary by raising `FlowCancelledError` (partial on `.result`), and abandoning the iterator cancels the backing task. Async-lane feature support applies (#388). - `FlowExecutor(..., step_cache=...)` → memoize step outputs across runs (#127); keyed by `(tool_name, schema_hash, input_value_hash)`. Cache hits skip `Tool.fn` entirely (including retries and timeout) and surface as `StepRecord.cached=True`. Tools mark themselves `cacheable=False` to always run (side-effects, external state). `replay_flow` always bypasses the cache. - `FlowExecutor(..., checkpointer=..., delete_on_success=True)` → crash-resume (#128); writes an `ExecutionSnapshot` after every successful linear step or DAG level. `FlowExecutor.resume_flow(trace_id)` (or `resume_flow_async(trace_id)` for runs started on the async lane, #388) validates the snapshot's flow version and tool `schema_hash` values against the current registry — drift raises `CheckpointDriftError` — then continues execution with the original `trace_id`. Snapshots are deleted on terminal success when `delete_on_success=True` (the default); preserved on failure for operator-driven retry. - `OTelTraceExporter(tracer=...)` from `chainweaver.integrations.opentelemetry` (#126) → emits OpenTelemetry spans as a `FlowExecutorMiddleware`: one parent `chainweaver.flow.{name}` span + one child `chainweaver.tool.{name}` span per `StepRecord`. After-the-fact export of a completed `ExecutionResult` via `export_result_to_otel(result, tracer=...)`. Optional extra: `pip install 'chainweaver[otel]'`. diff --git a/CHANGELOG.md b/CHANGELOG.md index 6517312..885cc55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **`stream_flow_async` and streamed-run cancellation** (#389): a new + `FlowExecutor.stream_flow_async(...)` async generator yields the same + `flow_start → (step_start → step_end)* → flow_end` `FlowEvent` sequence as + `stream_flow`, driving `execute_flow_async` directly on the calling loop + (no worker thread). A `cancel_token` / `deadline` ends the stream promptly + at the next step boundary, raising `FlowCancelledError` with the partial run + on `.result`; abandoning the iterator cancels the backing task. The sync + `stream_flow` also gains optional `cancel_token` / `deadline` parameters, + checked at step boundaries on the worker thread (the in-flight step still + completes). Event ordering and the "observability never aborts a flow" + contract are unchanged. + - **Async-lane parity: cache, checkpoint resume, sub-flow composition** (#388): `execute_flow_async` now consults the step cache (records `StepRecord.cached=True` on hits and skips `Tool.run_async`), writes diff --git a/chainweaver/executor.py b/chainweaver/executor.py index 4efe8c3..dd34ede 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -20,7 +20,7 @@ import threading import time import uuid -from collections.abc import Callable, Iterable, Iterator +from collections.abc import AsyncIterator, Callable, Iterable, Iterator from datetime import datetime, timezone from enum import Enum from graphlib import TopologicalSorter @@ -215,15 +215,19 @@ def copy(self) -> _RunScopedState: class _StreamCollectorMiddleware(BaseMiddleware): """Per-call middleware that pushes lifecycle events onto a queue. - Used by :meth:`FlowExecutor.stream_flow` to bridge the lifecycle - hook surface to a generator yielding :class:`FlowEvent` payloads. + Used by :meth:`FlowExecutor.stream_flow` (sync, thread-backed queue) and + :meth:`FlowExecutor.stream_flow_async` (asyncio queue) to bridge the + lifecycle hook surface to a generator yielding :class:`FlowEvent` + payloads. *emit* is the queue's enqueue callable — + :meth:`queue.Queue.put` for the sync variant, :meth:`asyncio.Queue.put_nowait` + for the async one (the queues are unbounded, so neither blocks). """ - def __init__(self, events: queue.Queue[FlowEvent | _StreamSentinel]) -> None: - self._events = events + def __init__(self, emit: Callable[[FlowEvent], None]) -> None: + self._emit = emit def on_flow_start(self, ctx: FlowStartContext) -> None: - self._events.put( + self._emit( FlowEvent( kind="flow_start", flow_name=ctx.flow_name, @@ -236,7 +240,7 @@ def on_flow_start(self, ctx: FlowStartContext) -> None: ) def on_step_start(self, ctx: StepStartContext) -> None: - self._events.put( + self._emit( FlowEvent( kind="step_start", flow_name=ctx.flow_name, @@ -249,7 +253,7 @@ def on_step_start(self, ctx: StepStartContext) -> None: ) def on_step_end(self, ctx: StepEndContext) -> None: - self._events.put( + self._emit( FlowEvent( kind="step_end", flow_name=ctx.flow_name, @@ -262,7 +266,7 @@ def on_step_end(self, ctx: StepEndContext) -> None: ) def on_flow_end(self, ctx: FlowEndContext) -> None: - self._events.put( + self._emit( FlowEvent( kind="flow_end", flow_name=ctx.flow_name, @@ -2969,6 +2973,8 @@ def stream_flow( initial_input: dict[str, Any], *, force: bool = False, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, ) -> Iterator[FlowEvent]: """Execute a flow and yield :class:`FlowEvent` lifecycle events (#134). @@ -2986,17 +2992,23 @@ def stream_flow( The flow runs on a background worker thread; events are delivered through a synchronized queue. - **Cancellation is not supported** for the sync variant. If the - consumer breaks out of the iteration (or otherwise lets the + **Consumer-driven cancellation is limited** for the sync variant. + If the consumer breaks out of the iteration (or otherwise lets the generator be garbage-collected), the generator's ``finally`` block blocks on ``thread.join()`` until the background flow finishes — for a 10-step flow with a long step 3 that means the caller's "stop iterating" intent is silently translated into "block here until everything completes". A ``WARNING`` is logged via the ``chainweaver.executor`` logger when this - happens so the behavior shows up in production traces. For - proper cancellation use the async variant once issue #80 - lands. + happens so the behavior shows up in production traces. + + A ``deadline`` / ``cancel_token`` *can* be supplied (issue #389): + they are forwarded to the underlying :meth:`execute_flow` and + checked at step boundaries on the worker thread, so an in-flight + tool always finishes but the flow stops before the next step — + the same cooperative contract as :meth:`execute_flow`. For + cancellation that also stops the consumer's ``await`` promptly, + use :meth:`stream_flow_async`. Hook exceptions and middleware exceptions still follow the catch-and-log contract from :class:`FlowExecutorMiddleware`: a @@ -3008,6 +3020,12 @@ def stream_flow( step. force: When ``True``, bypass the status guard and execute even if the flow is ``NEEDS_REVIEW`` or ``DISABLED``. + deadline: Optional wall-clock deadline (issue #142 / #389), + forwarded to :meth:`execute_flow` and checked between steps + on the worker thread. + cancel_token: Optional :class:`CancellationToken` (issue #142 / + #389), forwarded and checked between steps on the worker + thread. Yields: :class:`~chainweaver.events.FlowEvent` instances in the order @@ -3021,7 +3039,7 @@ def stream_flow( *force* is ``False``. Same re-raise behavior. """ events: queue.Queue[FlowEvent | _StreamSentinel] = queue.Queue() - collector = _StreamCollectorMiddleware(events) + collector = _StreamCollectorMiddleware(events.put) exc_holder: list[BaseException] = [] # Register the event collector as *run-scoped* middleware on the worker @@ -3033,7 +3051,13 @@ def stream_flow( def _worker() -> None: try: with self._scoped_middleware(collector): - self.execute_flow(flow_name, initial_input, force=force) + self.execute_flow( + flow_name, + initial_input, + force=force, + deadline=deadline, + cancel_token=cancel_token, + ) except BaseException as exc: exc_holder.append(exc) finally: @@ -3070,6 +3094,101 @@ def _worker() -> None: # run-scoped middleware slot (issue #336) and was popped when the # worker's ``_scoped_middleware`` context exited. + async def stream_flow_async( + self, + flow_name: str, + initial_input: dict[str, Any], + *, + force: bool = False, + deadline: float | None = None, + cancel_token: CancellationToken | None = None, + ) -> AsyncIterator[FlowEvent]: + """Async-native streaming counterpart to :meth:`stream_flow` (#389). + + Drives :meth:`execute_flow_async` directly on the calling event loop — + no worker thread — and yields the same lifecycle events in the same + order:: + + flow_start + (step_start, step_end)* # one pair per executed step + flow_end # on normal completion + + Steps that fail before input resolution (tool-not-found, input-mapping) + emit ``step_end`` without a preceding ``step_start`` — the same + middleware contract as the sync variant. + + Unlike :meth:`stream_flow`, cancellation is prompt and cooperative: + passing a ``cancel_token`` (or a ``deadline``) ends the stream at the + next step boundary by raising + :class:`~chainweaver.exceptions.FlowCancelledError` from the iterator — + its :attr:`~chainweaver.exceptions.FlowCancelledError.result` carries + the partial run. No ``flow_end`` is emitted on cancellation (execution + raised before the flow-end hook). If the consumer stops iterating + early, the backing :meth:`execute_flow_async` task is cancelled. + + The async lane's feature support applies (issue #388): composed + sub-flows, the step cache, and checkpoints work; conditional branching + and decision callbacks raise :class:`AsyncLaneUnsupportedError` before + any event is emitted. + + Args: + flow_name: Name of the flow to execute. + initial_input: Initial key/value context passed to the first step. + force: When ``True``, bypass the status guard. + deadline: Optional wall-clock deadline (issue #142); checked + between steps / DAG levels. + cancel_token: Optional :class:`CancellationToken` (issue #142); + checked between steps / DAG levels. + + Yields: + :class:`~chainweaver.events.FlowEvent` instances in the order above. + + Raises: + FlowCancelledError: When *deadline* has passed or *cancel_token* is + cancelled at a step boundary. + AsyncLaneUnsupportedError: When the flow uses async-unsupported + features (raised before any event is yielded). + """ + events: asyncio.Queue[FlowEvent | _StreamSentinel] = asyncio.Queue() + collector = _StreamCollectorMiddleware(events.put_nowait) + exc_holder: list[BaseException] = [] + + async def _runner() -> None: + # Register the collector as run-scoped middleware on this task's + # context (issue #336): ``execute_flow_async`` inherits it through + # its own ``_run_scope`` copy, so concurrent streams never see each + # other's events and the shared middleware list is never mutated. + try: + with self._scoped_middleware(collector): + await self.execute_flow_async( + flow_name, + initial_input, + force=force, + deadline=deadline, + cancel_token=cancel_token, + ) + except BaseException as exc: # surfaced from the consumer below + exc_holder.append(exc) + finally: + events.put_nowait(_STREAM_SENTINEL) + + task = asyncio.create_task(_runner()) + try: + while True: + item = await events.get() + if isinstance(item, _StreamSentinel): + if exc_holder: + raise exc_holder[0] + return + yield item + finally: + # Consumer stopped early (``break`` / ``aclose``) or an event was + # re-raised: cancel the backing run so it cannot outlive the stream. + if not task.done(): + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ diff --git a/tests/test_streaming_async.py b/tests/test_streaming_async.py new file mode 100644 index 0000000..440dd59 --- /dev/null +++ b/tests/test_streaming_async.py @@ -0,0 +1,194 @@ +"""Tests for FlowExecutor.stream_flow_async + sync stream_flow cancellation (#389).""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any + +import pytest +from helpers import NumberInput, ValueInput, ValueOutput, _add_ten_fn, _double_fn +from pydantic import BaseModel + +from chainweaver import CancellationToken +from chainweaver.events import FlowEvent +from chainweaver.exceptions import FlowCancelledError +from chainweaver.executor import FlowExecutor +from chainweaver.flow import DAGFlow, DAGFlowStep, Flow, FlowStep +from chainweaver.registry import FlowRegistry +from chainweaver.tools import Tool + + +class _NIn(BaseModel): + n: int + + +class _Out(BaseModel): + value: int + + +async def _async_inc(inp: _NIn) -> dict[str, Any]: + await asyncio.sleep(0) + return {"value": inp.n + 1} + + +async def _async_double(inp: _Out) -> dict[str, Any]: + await asyncio.sleep(0) + return {"value": inp.value * 2} + + +def _linear_executor() -> FlowExecutor: + registry = FlowRegistry() + registry.register_flow( + Flow( + name="lin", + version="1.0.0", + description="", + steps=[ + FlowStep(tool_name="inc", input_mapping={"n": "n"}), + FlowStep(tool_name="dbl", input_mapping={"value": "value"}), + ], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + Tool(name="inc", description="", input_schema=_NIn, output_schema=_Out, fn=_async_inc) + ) + ex.register_tool( + Tool(name="dbl", description="", input_schema=_Out, output_schema=_Out, fn=_async_double) + ) + return ex + + +async def test_async_stream_event_order_linear() -> None: + ex = _linear_executor() + kinds = [e.kind async for e in ex.stream_flow_async("lin", {"n": 4})] + assert kinds == ["flow_start", "step_start", "step_end", "step_start", "step_end", "flow_end"] + + +async def test_async_stream_flow_end_carries_result() -> None: + ex = _linear_executor() + events = [e async for e in ex.stream_flow_async("lin", {"n": 4})] + end = events[-1] + assert end.kind == "flow_end" + assert end.result is not None + assert end.result.success is True + # (4 + 1) * 2 == 10 + assert end.result.final_output is not None + assert end.result.final_output["value"] == 10 + + +async def test_async_stream_event_order_dag() -> None: + registry = FlowRegistry() + registry.register_flow( + DAGFlow( + name="dag", + version="1.0.0", + description="", + steps=[ + DAGFlowStep(step_id="a", tool_name="inc", input_mapping={"n": "n"}), + DAGFlowStep( + step_id="b", + tool_name="dbl", + input_mapping={"value": "value"}, + depends_on=["a"], + ), + ], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + Tool(name="inc", description="", input_schema=_NIn, output_schema=_Out, fn=_async_inc) + ) + ex.register_tool( + Tool(name="dbl", description="", input_schema=_Out, output_schema=_Out, fn=_async_double) + ) + kinds = [e.kind async for e in ex.stream_flow_async("dag", {"n": 3})] + assert kinds == ["flow_start", "step_start", "step_end", "step_start", "step_end", "flow_end"] + + +async def test_async_stream_events_are_json_serializable() -> None: + ex = _linear_executor() + async for event in ex.stream_flow_async("lin", {"n": 1}): + # Round-trips through JSON for non-Python stream consumers. + restored = FlowEvent.model_validate_json(event.model_dump_json()) + assert restored.kind == event.kind + + +async def test_async_stream_cancels_at_step_boundary() -> None: + token = CancellationToken() + token.cancel() # cancelled before the first step boundary + ex = _linear_executor() + collected: list[str] = [] + with pytest.raises(FlowCancelledError) as excinfo: + async for event in ex.stream_flow_async("lin", {"n": 1}, cancel_token=token): + collected.append(event.kind) + # The stream opened (flow_start) and emitted the terminal flow_end carrying + # the partial result, then stopped before running step 0 — no step events. + assert collected == ["flow_start", "flow_end"] + # The partial result is available on the raised error. + assert excinfo.value.result is not None + + +async def test_async_stream_deadline_ends_stream() -> None: + ex = _linear_executor() + with pytest.raises(FlowCancelledError): + async for _ in ex.stream_flow_async("lin", {"n": 1}, deadline=time.time() - 1.0): + pass + + +# --------------------------------------------------------------------------- +# Sync stream_flow now honours cancel_token at step boundaries (#389) +# --------------------------------------------------------------------------- + + +def _sync_executor() -> FlowExecutor: + registry = FlowRegistry() + registry.register_flow( + Flow( + name="s", + version="0.1.0", + description="", + steps=[ + FlowStep(tool_name="double", input_mapping={"number": "number"}), + FlowStep(tool_name="add_ten", input_mapping={"value": "value"}), + ], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + Tool( + name="double", + description="", + input_schema=NumberInput, + output_schema=ValueOutput, + fn=_double_fn, + ) + ) + ex.register_tool( + Tool( + name="add_ten", + description="", + input_schema=ValueInput, + output_schema=ValueOutput, + fn=_add_ten_fn, + ) + ) + return ex + + +def test_sync_stream_flow_honors_cancel_token() -> None: + token = CancellationToken() + token.cancel() + ex = _sync_executor() + collected: list[str] = [] + with pytest.raises(FlowCancelledError): + for event in ex.stream_flow("s", {"number": 1}, cancel_token=token): + collected.append(event.kind) + assert collected == ["flow_start", "flow_end"] + + +def test_sync_stream_flow_unchanged_without_cancel_token() -> None: + ex = _sync_executor() + kinds = [e.kind for e in ex.stream_flow("s", {"number": 4})] + assert kinds == ["flow_start", "step_start", "step_end", "step_start", "step_end", "flow_end"] From 8bb4c02f1a15cdce0fa5d5026ef3b3ef25535420 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 22 Jun 2026 20:41:04 +0000 Subject: [PATCH 4/6] feat(tools,executor): StreamingTool + step_chunk propagation (#320) Add StreamingTool (a Tool subclass) and ToolChunk: a streaming tool yields intermediate chunks plus a terminal is_final chunk carrying the schema-validated assembled output. stream_flow_async surfaces each chunk as a new FlowEvent(kind='step_chunk', chunk=...) between step_start and step_end, via a new additive on_step_chunk middleware hook + StepChunkContext. Streaming tools stay fully backward compatible on non-streaming paths (run/run_async/sync executor drain to the assembled output). Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_011U2ZKg2GMec7iAUDt2zezq --- AGENTS.md | 6 +- CHANGELOG.md | 15 +++ chainweaver/__init__.py | 6 +- chainweaver/events.py | 9 +- chainweaver/executor.py | 136 +++++++++++++++++++++- chainweaver/middleware.py | 40 +++++++ chainweaver/tools.py | 130 ++++++++++++++++++++- tests/fixtures/public_api.json | 33 +++++- tests/test_streaming_tools.py | 203 +++++++++++++++++++++++++++++++++ 9 files changed, 568 insertions(+), 10 deletions(-) create mode 100644 tests/test_streaming_tools.py diff --git a/AGENTS.md b/AGENTS.md index cc8b0d5..19e5731 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -45,7 +45,7 @@ chainweaver/ ├── contracts.py ToolSafetyContract + SideEffectLevel/StabilityLevel/DeterminismLevel enums + merge_safety() + side_effect_exceeds() (#356) + evaluate_predicate() — determinism + operational safety vocabulary (#19, #125, #293, #9, #8) ├── approvals.py ApprovalCallback Protocol + ApprovalContext/ApprovalDecision/ApprovalRecord + coerce_approval_callback — execution-time ToolSafetyContract enforcement seam (#356); mirrors decisions.py ├── decorators.py @tool decorator for zero-boilerplate tool definition -├── tools.py Tool class: named callable with Pydantic I/O schemas + schema_hash + safety contract (#19) + metadata provenance (#358/#359/#371) + dry_run_fn/run_dry (#357); Tool.from_flow() wraps a Flow as a Tool (#24) with derived safety (#125) +├── tools.py Tool class: named callable with Pydantic I/O schemas + schema_hash + safety contract (#19) + metadata provenance (#358/#359/#371) + dry_run_fn/run_dry (#357); Tool.from_flow() wraps a Flow as a Tool (#24) with derived safety (#125); StreamingTool + ToolChunk for streamed output via run_streaming (#320) ├── flow.py FlowStep (+ output_mapping #386) + Flow + DAGFlow (+ dynamic_params #316) + FlowStatus + FlowLifecycle + FlowGovernance + DriftInfo + ConditionalEdge (#9) + determinism_level property (#8) + ContextCollisionPolicy / on_context_collision (#337) ├── step_index.py Named sentinels for flow input/output validation records (#339) ├── _pointer.py Dependency-free RFC-6901 JSON pointer resolver shared by executor input_mapping (#387) and contrib json_pluck @@ -58,8 +58,8 @@ chainweaver/ ├── _execution/ Internal, no-I/O execution collaborators shared by both lanes (#330, #331); banned from importing LLM/network/random — see invariants │ ├── __init__.py Re-exports merge_step_outputs + apply_output_mapping │ └── context.py merge_step_outputs + apply_output_mapping: single context-merge honouring on_context_collision (#337) and output_mapping (#386) -├── middleware.py FlowExecutorMiddleware Protocol + lifecycle context models + BaseMiddleware (#131) -├── events.py FlowEvent streamable lifecycle payload yielded by FlowExecutor.stream_flow (#134) +├── middleware.py FlowExecutorMiddleware Protocol + lifecycle context models + BaseMiddleware (#131); optional on_step_chunk hook + StepChunkContext for streaming steps (#320) +├── events.py FlowEvent streamable lifecycle payload yielded by FlowExecutor.stream_flow / stream_flow_async (#134, #389) — incl. kind="step_chunk" carrying a ToolChunk for streaming tools (#320) ├── cache.py StepCache Protocol + InMemoryStepCache + FileStepCache + StepCacheKey (#127) ├── checkpoint.py Checkpointer Protocol + ExecutionSnapshot + InMemoryCheckpointer + FileCheckpointer (#128) ├── integrations/ Optional third-party adapters (each guards its extra import) diff --git a/CHANGELOG.md b/CHANGELOG.md index 885cc55..70fdbdb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- **Streaming tool output propagation** (#320): a new + `chainweaver.StreamingTool` (a `Tool` subclass) produces its output as a + stream of `chainweaver.ToolChunk` objects via an async `run_streaming` + generator — zero or more intermediate chunks followed by one terminal + `is_final=True` chunk whose data is the schema-validated assembled output. + `stream_flow_async` surfaces each chunk as a new `FlowEvent(kind="step_chunk", + chunk=...)`, interleaved between `step_start` and `step_end`, so real-time + pipelines (voice, A2A, SSE) can consume partial output as it is produced. + Streaming tools are fully backward compatible: on the non-streaming paths + (`run` / `run_async` / sync `execute_flow` / non-streamed + `execute_flow_async`) they transparently drain to the assembled output and + behave like any other tool. A new optional `on_step_chunk` middleware hook + (with a `StepChunkContext`) receives chunks; it is additive and dispatched + only to middleware that define it, so existing middleware are unaffected. + - **`stream_flow_async` and streamed-run cancellation** (#389): a new `FlowExecutor.stream_flow_async(...)` async generator yields the same `flow_start → (step_start → step_end)* → flow_end` `FlowEvent` sequence as diff --git a/chainweaver/__init__.py b/chainweaver/__init__.py index 0a5fce6..3df7ed4 100644 --- a/chainweaver/__init__.py +++ b/chainweaver/__init__.py @@ -196,6 +196,7 @@ FlowEndContext, FlowExecutorMiddleware, FlowStartContext, + StepChunkContext, StepEndContext, StepStartContext, ) @@ -259,7 +260,7 @@ from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index from chainweaver.storage import FileStore, InMemoryStore, RegistryStore from chainweaver.testing.replay import FixtureStaleError -from chainweaver.tools import Tool +from chainweaver.tools import StreamingTool, Tool, ToolChunk from chainweaver.traces import ( AgentTraceEvent, BacktestMismatch, @@ -444,15 +445,18 @@ "StabilityLevel", "StepCache", "StepCacheKey", + "StepChunkContext", "StepDiff", "StepEndContext", "StepPlan", "StepRecord", "StepStartContext", + "StreamingTool", "StructuredLLMFn", "Suggestion", "Tool", "ToolChain", + "ToolChunk", "ToolDefinitionError", "ToolDescriptionProposal", "ToolNotFoundError", diff --git a/chainweaver/events.py b/chainweaver/events.py index 109fe9b..5920ff2 100644 --- a/chainweaver/events.py +++ b/chainweaver/events.py @@ -43,11 +43,13 @@ from pydantic import BaseModel, ConfigDict +from chainweaver.tools import ToolChunk + if TYPE_CHECKING: # pragma: no cover — import-cycle guard from chainweaver.executor import ExecutionResult, StepRecord -FlowEventKind = Literal["flow_start", "step_start", "step_end", "flow_end"] +FlowEventKind = Literal["flow_start", "step_start", "step_chunk", "step_end", "flow_end"] class FlowEvent(BaseModel): @@ -66,6 +68,8 @@ class FlowEvent(BaseModel): +--------------+----------------------------------------------------+ | step_start | ``step_index``, ``tool_name``, ``inputs`` | +--------------+----------------------------------------------------+ + | step_chunk | ``step_index``, ``tool_name``, ``chunk`` | + +--------------+----------------------------------------------------+ | step_end | ``step_index``, ``tool_name``, ``step_record`` | +--------------+----------------------------------------------------+ | flow_end | ``result`` | @@ -93,6 +97,8 @@ class FlowEvent(BaseModel): step_record: Final :class:`~chainweaver.executor.StepRecord` for the step (``step_end`` only) — inspect ``step_record.success`` to branch. + chunk: The :class:`~chainweaver.tools.ToolChunk` produced by a + streaming step (``step_chunk`` only, issue #320). result: Full :class:`~chainweaver.executor.ExecutionResult` (``flow_end`` only). """ @@ -110,6 +116,7 @@ class FlowEvent(BaseModel): tool_name: str | None = None inputs: dict[str, Any] | None = None step_record: StepRecord | None = None + chunk: ToolChunk | None = None result: ExecutionResult | None = None diff --git a/chainweaver/executor.py b/chainweaver/executor.py index dd34ede..0ac66de 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -78,6 +78,7 @@ PredicateSyntaxError, SafetyCeilingError, SchemaValidationError, + ToolDefinitionError, ToolNotFoundError, ToolOutputSizeError, ToolTimeoutError, @@ -104,13 +105,14 @@ FlowEndContext, FlowExecutorMiddleware, FlowStartContext, + StepChunkContext, StepEndContext, StepStartContext, ) from chainweaver.observation import TraceRecorder from chainweaver.registry import AnyFlow, FlowRegistry from chainweaver.step_index import FLOW_INPUT_STEP_INDEX, flow_output_step_index -from chainweaver.tools import Tool +from chainweaver.tools import StreamingTool, Tool _logger = get_logger("chainweaver.executor") _middleware_logger = get_logger("chainweaver.middleware") @@ -252,6 +254,19 @@ def on_step_start(self, ctx: StepStartContext) -> None: ) ) + def on_step_chunk(self, ctx: StepChunkContext) -> None: + self._emit( + FlowEvent( + kind="step_chunk", + flow_name=ctx.flow_name, + trace_id=ctx.trace_id, + timestamp=_now_utc(), + step_index=ctx.step_index, + tool_name=ctx.tool_name, + chunk=ctx.chunk, + ) + ) + def on_step_end(self, ctx: StepEndContext) -> None: self._emit( FlowEvent( @@ -890,7 +905,11 @@ def remove_middleware(self, middleware: FlowExecutorMiddleware) -> None: def _fire_hook( self, hook: str, - ctx: FlowStartContext | StepStartContext | StepEndContext | FlowEndContext, + ctx: FlowStartContext + | StepStartContext + | StepChunkContext + | StepEndContext + | FlowEndContext, ) -> None: """Dispatch *hook* to every registered middleware, catching exceptions. @@ -987,6 +1006,9 @@ def _fire_step_start(self, ctx: StepStartContext) -> None: def _fire_step_end(self, ctx: StepEndContext) -> None: self._fire_hook("on_step_end", ctx) + def _fire_step_chunk(self, ctx: StepChunkContext) -> None: + self._fire_hook("on_step_chunk", ctx) + def _fire_flow_end(self, ctx: FlowEndContext) -> None: self._fire_hook("on_flow_end", ctx) @@ -2609,6 +2631,25 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + # Streaming tool (issue #320): consume the chunk stream, emitting a + # ``step_chunk`` event per chunk (surfaced by ``stream_flow_async``), + # and use the terminal chunk's assembled — already output-schema + # validated — data as the step output. The step cache is bypassed + # (streaming tools are I/O-bound and typically non-deterministic). + if isinstance(tool, StreamingTool): + return _finish( + await self._run_streaming_step_async( + tool=tool, + step=step, + step_index=step_index, + inputs=inputs, + flow_name=flow_name, + trace_id=trace_id, + tool_attempts=tool_attempts, + record_fn=_record, + ) + ) + # Step cache lookup (issue #127 / #388) — mirrors the sync lane. Hash # the *validated* inputs so equivalent payloads collapse onto one key; # bypass for non-cacheable tools and during replay. @@ -2739,6 +2780,97 @@ def _finish(record: StepRecord) -> StepRecord: ) ) + async def _run_streaming_step_async( + self, + *, + tool: StreamingTool, + step: FlowStep, + step_index: int, + inputs: dict[str, Any], + flow_name: str, + trace_id: str, + tool_attempts: list[int], + record_fn: Callable[..., StepRecord], + ) -> StepRecord: + """Drive a :class:`StreamingTool`, emitting per-chunk events (#320). + + Consumes :meth:`StreamingTool.run_streaming`, firing an + ``on_step_chunk`` hook per :class:`ToolChunk`, and returns the step's + :class:`StepRecord` built from the terminal chunk's assembled output. + Tool failures route through the same async ``on_error`` machinery as a + normal step (fail / skip / fallback). + """ + tool_attempts[0] += 1 + retry_errors: list[str] = [] + try: + final_output: dict[str, Any] | None = None + async for chunk in tool.run_streaming(inputs): + self._fire_step_chunk( + StepChunkContext( + trace_id=trace_id, + flow_name=flow_name, + step_index=step_index, + tool_name=step.display_name, + chunk=chunk, + ) + ) + if chunk.is_final: + final_output = chunk.data + if final_output is None: + raise ToolDefinitionError( + tool.name, "streaming tool produced no terminal (is_final=True) chunk." + ) + except Exception as exc: + wrapped = self._wrap_tool_exception(step, step_index, exc) + log_step_error(_logger, step_index, step.display_name, wrapped) + return await self._apply_on_error_async( + step=step, + step_index=step_index, + inputs=inputs, + wrapped_error=wrapped, + retry_errors=retry_errors, + make_record=record_fn, + ) + + log_step_end( + _logger, + step_index, + step.display_name, + final_output, + redaction=self._redaction_policy, + ) + + # Step-level output contract (issue #172) — mirrors the non-streaming path. + if step.output_contract is not None: + output_contract_cls = step.resolved_output_contract + assert output_contract_cls is not None + out_contract_err = self._check_step_contract( + step=step, + step_index=step_index, + payload=final_output, + contract=output_contract_cls, + context_label="step_output_contract", + ) + if out_contract_err is not None: + log_step_error(_logger, step_index, step.display_name, out_contract_err) + return record_fn( + inputs=inputs, + outputs=None, + error=out_contract_err, + success=False, + skipped=False, + retry_errors=retry_errors, + ) + + return record_fn( + inputs=inputs, + outputs=final_output, + error=None, + success=True, + skipped=False, + retry_errors=retry_errors, + ) + async def _execute_subflow_step_async( self, step_index: int, diff --git a/chainweaver/middleware.py b/chainweaver/middleware.py index daf9834..5dd1afb 100644 --- a/chainweaver/middleware.py +++ b/chainweaver/middleware.py @@ -38,6 +38,8 @@ from pydantic import BaseModel, ConfigDict +from chainweaver.tools import ToolChunk + if TYPE_CHECKING: # pragma: no cover — import-cycle guard from chainweaver.executor import ExecutionResult, StepRecord @@ -120,6 +122,33 @@ class StepEndContext(BaseModel): step_record: StepRecord +class StepChunkContext(BaseModel): + """Context passed to ``on_step_chunk`` for a streaming step (issue #320). + + Fired once per :class:`~chainweaver.tools.ToolChunk` produced by a + :class:`~chainweaver.tools.StreamingTool` while it runs under + :meth:`~chainweaver.executor.FlowExecutor.stream_flow_async` — interleaved + between that step's ``on_step_start`` and ``on_step_end``. Non-streaming + tools and the non-streaming execution paths never fire it, so this hook is + purely additive. + + Attributes: + trace_id: UUID4 hex string matching the parent flow's trace id. + flow_name: Name of the flow being executed. + step_index: Zero-based position of the streaming step in the flow. + tool_name: Name of the streaming tool producing the chunk. + chunk: The :class:`~chainweaver.tools.ToolChunk` just produced. + """ + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + + trace_id: str + flow_name: str + step_index: int + tool_name: str + chunk: ToolChunk + + class FlowEndContext(BaseModel): """Context passed to :meth:`FlowExecutorMiddleware.on_flow_end`. @@ -202,6 +231,16 @@ def on_step_start(self, ctx: StepStartContext) -> None: def on_step_end(self, ctx: StepEndContext) -> None: """Default no-op.""" + def on_step_chunk(self, ctx: StepChunkContext) -> None: + """Default no-op (issue #320). + + Fired per :class:`~chainweaver.tools.ToolChunk` while a streaming step + runs under ``stream_flow_async``. Optional and additive — it is not + part of the :class:`FlowExecutorMiddleware` Protocol, so existing + middleware are unaffected; the executor dispatches it only to + middleware that define it. + """ + def on_flow_end(self, ctx: FlowEndContext) -> None: """Default no-op.""" @@ -211,6 +250,7 @@ def on_flow_end(self, ctx: FlowEndContext) -> None: "FlowEndContext", "FlowExecutorMiddleware", "FlowStartContext", + "StepChunkContext", "StepEndContext", "StepStartContext", ] diff --git a/chainweaver/tools.py b/chainweaver/tools.py index 9f84022..42fe9c0 100644 --- a/chainweaver/tools.py +++ b/chainweaver/tools.py @@ -29,13 +29,13 @@ import hashlib import inspect import json -from collections.abc import Awaitable, Callable +from collections.abc import AsyncIterator, Awaitable, Callable from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError from functools import cached_property from typing import TYPE_CHECKING, Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from chainweaver.compat import schema_fingerprint from chainweaver.contracts import ToolSafetyContract, merge_safety @@ -78,6 +78,32 @@ def _is_async_callable(fn: Callable[..., Any]) -> bool: return bound is not None and inspect.iscoroutinefunction(bound) +class ToolChunk(BaseModel): + """One streamed increment of a :class:`StreamingTool`'s output (issue #320). + + A streaming tool yields a sequence of ``ToolChunk`` objects: zero or more + intermediate chunks (``is_final=False``) carrying partial data — tokens, an + A2A artifact, an SSE event — followed by exactly one terminal chunk + (``is_final=True``) whose :attr:`data` is the assembled output ``dict`` that + is validated against the tool's ``output_schema`` and merged into the flow + context. + + Attributes: + data: The chunk payload. For intermediate chunks this is arbitrary + partial data; for the terminal chunk it must be a ``dict`` + compatible with the tool's ``output_schema``. Must be + JSON-serializable so the chunk round-trips inside a + :class:`~chainweaver.events.FlowEvent`. + is_final: ``True`` for the single terminal chunk that carries the + assembled output; ``False`` for intermediate chunks. + """ + + model_config = ConfigDict(frozen=True) + + data: Any + is_final: bool = False + + class Tool: """A named, schema-validated callable unit of work. @@ -632,6 +658,106 @@ def _flow_fn(validated_input: BaseModel) -> dict[str, Any]: return wrapped +class StreamingTool(Tool): + """A :class:`Tool` that streams its output as a sequence of chunks (#320). + + Where a plain tool returns a fully-collected ``dict``, a streaming tool + yields :class:`ToolChunk` objects as they are produced — LLM tokens, A2A + streaming events, SSE deltas — so downstream consumers (a UI, a TTS step) + can start working before the whole output is ready. + + The streaming function is an ``async`` generator taking the validated + input model and yielding :class:`ToolChunk`; it must end with exactly one + ``is_final=True`` chunk whose ``data`` is the assembled output ``dict``. + + Backwards compatibility: a ``StreamingTool`` is still a ``Tool``. When run + through a non-streaming path (``run`` / ``run_async`` / the sync executor, + or ``execute_flow_async`` without streaming consumption) it transparently + drains the stream and returns the terminal chunk's assembled output, so it + behaves exactly like an ordinary tool. Only + :meth:`~chainweaver.executor.FlowExecutor.stream_flow_async` surfaces the + intermediate chunks as ``step_chunk`` events. + + Example:: + + async def _tokens(inp: Query) -> AsyncIterator[ToolChunk]: + text = "" + for token in ("hel", "lo"): + text += token + yield ToolChunk(data={"delta": token}) + yield ToolChunk(data={"text": text}, is_final=True) + + tool = StreamingTool( + name="generate", + description="Stream a completion.", + input_schema=Query, + output_schema=Completion, + stream_fn=_tokens, + ) + """ + + def __init__( + self, + *, + name: str, + description: str, + input_schema: type[BaseModel], + output_schema: type[BaseModel], + stream_fn: Callable[[Any], AsyncIterator[ToolChunk]], + timeout_seconds: float | None = None, + max_output_size: int | None = None, + schema_version: str = "0.0.0", + cacheable: bool | None = None, + safety: ToolSafetyContract | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + self.stream_fn = stream_fn + + async def _drain(validated_input: BaseModel) -> dict[str, Any]: + # Non-streaming dispatch path: consume the stream and return the + # terminal chunk's assembled output so the tool behaves like any + # other tool under ``run`` / ``run_async`` / the sync executor. + final: dict[str, Any] | None = None + async for chunk in stream_fn(validated_input): + if chunk.is_final: + final = chunk.data + if final is None: + raise ToolDefinitionError( + name, "streaming tool produced no terminal (is_final=True) chunk." + ) + return final + + super().__init__( + name=name, + description=description, + input_schema=input_schema, + output_schema=output_schema, + fn=_drain, + timeout_seconds=timeout_seconds, + max_output_size=max_output_size, + schema_version=schema_version, + cacheable=cacheable, + safety=safety, + metadata=metadata, + ) + + async def run_streaming(self, raw_inputs: dict[str, Any]) -> AsyncIterator[ToolChunk]: + """Validate *raw_inputs* and yield the tool's :class:`ToolChunk` stream. + + Intermediate chunks are yielded verbatim; the terminal + (``is_final=True``) chunk's ``data`` is validated against the tool's + ``output_schema`` (and size cap) before being yielded, so a streaming + tool can never emit an invalid assembled output. + """ + validated_input = self.input_schema.model_validate(raw_inputs) + async for chunk in self.stream_fn(validated_input): + if chunk.is_final: + validated = self._validate_output(chunk.data) + yield ToolChunk(data=validated, is_final=True) + else: + yield chunk + + def _terminal_step(flow: Flow | DAGFlow) -> FlowStep: """Return the sole terminal step of *flow* for output-schema derivation. diff --git a/tests/fixtures/public_api.json b/tests/fixtures/public_api.json index b56565c..35cf05b 100644 --- a/tests/fixtures/public_api.json +++ b/tests/fixtures/public_api.json @@ -146,15 +146,18 @@ "StabilityLevel", "StepCache", "StepCacheKey", + "StepChunkContext", "StepDiff", "StepEndContext", "StepPlan", "StepRecord", "StepStartContext", + "StreamingTool", "StructuredLLMFn", "Suggestion", "Tool", "ToolChain", + "ToolChunk", "ToolDefinitionError", "ToolDescriptionProposal", "ToolNotFoundError", @@ -835,11 +838,12 @@ "FlowEvent": { "kind": "pydantic-model", "model_fields": { + "chunk": "chainweaver.tools.ToolChunk | NoneType", "flow_name": "str", "flow_version": "str | NoneType", "initial_input": "dict[str, Any] | NoneType", "inputs": "dict[str, Any] | NoneType", - "kind": "Literal['flow_start', 'step_start', 'step_end', 'flow_end']", + "kind": "Literal['flow_start', 'step_start', 'step_chunk', 'step_end', 'flow_end']", "result": "chainweaver.executor.ExecutionResult | NoneType", "step_index": "int | NoneType", "step_record": "chainweaver.executor.StepRecord | NoneType", @@ -1445,6 +1449,18 @@ "module": "chainweaver.cache", "qualname": "StepCacheKey" }, + "StepChunkContext": { + "kind": "pydantic-model", + "model_fields": { + "chunk": "chainweaver.tools.ToolChunk", + "flow_name": "str", + "step_index": "int", + "tool_name": "str", + "trace_id": "str" + }, + "module": "chainweaver.middleware", + "qualname": "StepChunkContext" + }, "StepDiff": { "kind": "pydantic-model", "model_fields": { @@ -1522,6 +1538,12 @@ "module": "chainweaver.middleware", "qualname": "StepStartContext" }, + "StreamingTool": { + "kind": "class", + "module": "chainweaver.tools", + "qualname": "StreamingTool", + "signature": "(*, name: str, description: str, input_schema: type[BaseModel], output_schema: type[BaseModel], stream_fn: Callable[[Any], AsyncIterator[ToolChunk]], timeout_seconds: float | None = None, max_output_size: int | None = None, schema_version: str = '0.0.0', cacheable: bool | None = None, safety: ToolSafetyContract | None = None, metadata: dict[str, Any] | None = None) -> None" + }, "StructuredLLMFn": { "kind": "class", "module": "chainweaver.proposals", @@ -1551,6 +1573,15 @@ "module": "builtins", "qualname": "tuple" }, + "ToolChunk": { + "kind": "pydantic-model", + "model_fields": { + "data": "Any", + "is_final": "bool" + }, + "module": "chainweaver.tools", + "qualname": "ToolChunk" + }, "ToolDefinitionError": { "kind": "class", "module": "chainweaver.exceptions", diff --git a/tests/test_streaming_tools.py b/tests/test_streaming_tools.py new file mode 100644 index 0000000..9b12620 --- /dev/null +++ b/tests/test_streaming_tools.py @@ -0,0 +1,203 @@ +"""Tests for StreamingTool / ToolChunk and step_chunk propagation (issue #320).""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from pydantic import BaseModel + +from chainweaver import ( + Flow, + FlowExecutor, + FlowRegistry, + FlowStep, + StreamingTool, + Tool, + ToolChunk, +) +from chainweaver.events import FlowEvent + + +class _Query(BaseModel): + prompt: str + + +class _Completion(BaseModel): + text: str + + +async def _token_stream(inp: _Query) -> AsyncIterator[ToolChunk]: + text = "" + for token in ("hel", "lo", "!"): + text += token + yield ToolChunk(data={"delta": token}) + yield ToolChunk(data={"text": text}, is_final=True) + + +def _streaming_executor() -> FlowExecutor: + registry = FlowRegistry() + registry.register_flow( + Flow( + name="gen_flow", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="generate", input_mapping={"prompt": "prompt"})], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + StreamingTool( + name="generate", + description="Stream a completion.", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_token_stream, + ) + ) + return ex + + +# -------------------------------------------------------------------------- +# Backward compatibility: streaming tools work on non-streaming paths +# -------------------------------------------------------------------------- + + +def test_streaming_tool_is_a_tool() -> None: + tool = StreamingTool( + name="generate", + description="", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_token_stream, + ) + assert isinstance(tool, Tool) + + +async def test_streaming_tool_run_async_drains_to_final_output() -> None: + tool = StreamingTool( + name="generate", + description="", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_token_stream, + ) + out = await tool.run_async({"prompt": "hi"}) + assert out == {"text": "hello!"} + + +def test_streaming_tool_runs_on_sync_executor() -> None: + ex = _streaming_executor() + result = ex.execute_flow("gen_flow", {"prompt": "hi"}) + assert result.success is True + assert result.final_output is not None + assert result.final_output["text"] == "hello!" + + +async def test_streaming_tool_runs_on_async_executor_without_streaming() -> None: + ex = _streaming_executor() + result = await ex.execute_flow_async("gen_flow", {"prompt": "hi"}) + assert result.success is True + assert result.final_output is not None + assert result.final_output["text"] == "hello!" + + +# -------------------------------------------------------------------------- +# stream_flow_async surfaces step_chunk events +# -------------------------------------------------------------------------- + + +async def test_stream_flow_async_emits_step_chunks() -> None: + ex = _streaming_executor() + kinds: list[str] = [] + deltas: list[str] = [] + async for event in ex.stream_flow_async("gen_flow", {"prompt": "hi"}): + kinds.append(event.kind) + if event.kind == "step_chunk": + assert event.chunk is not None + if not event.chunk.is_final: + deltas.append(event.chunk.data["delta"]) + # Chunks are interleaved between step_start and step_end. + assert kinds[0] == "flow_start" + assert kinds[-1] == "flow_end" + assert "step_start" in kinds + assert kinds.count("step_chunk") == 4 # 3 deltas + 1 final + start = kinds.index("step_start") + end = kinds.index("step_end") + assert all(start < i < end for i, k in enumerate(kinds) if k == "step_chunk") + assert deltas == ["hel", "lo", "!"] + + +async def test_step_chunk_event_is_json_serializable() -> None: + ex = _streaming_executor() + async for event in ex.stream_flow_async("gen_flow", {"prompt": "hi"}): + if event.kind == "step_chunk": + restored = FlowEvent.model_validate_json(event.model_dump_json()) + assert restored.chunk is not None + assert restored.chunk.data == event.chunk.data # type: ignore[union-attr] + + +async def test_non_streaming_tool_emits_no_step_chunks() -> None: + def _double(inp: _DoubleIn) -> dict[str, Any]: + return {"value": inp.n * 2} + + class _DoubleIn(BaseModel): + n: int + + class _DoubleOut(BaseModel): + value: int + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="plain", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="double", input_mapping={"n": "n"})], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + Tool( + name="double", + description="", + input_schema=_DoubleIn, + output_schema=_DoubleOut, + fn=_double, + ) + ) + kinds = [e.kind async for e in ex.stream_flow_async("plain", {"n": 3})] + assert "step_chunk" not in kinds + assert kinds == ["flow_start", "step_start", "step_end", "flow_end"] + + +# -------------------------------------------------------------------------- +# Failure paths +# -------------------------------------------------------------------------- + + +async def test_streaming_tool_without_terminal_chunk_fails_step() -> None: + async def _no_final(inp: _Query) -> AsyncIterator[ToolChunk]: + yield ToolChunk(data={"delta": "x"}) # never sets is_final + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="bad", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="g", input_mapping={"prompt": "prompt"})], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + StreamingTool( + name="g", + description="", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_no_final, + ) + ) + result = await ex.execute_flow_async("bad", {"prompt": "hi"}) + assert result.success is False From 5af0158001c0fc134b7b68b7faaf7e251e932d4f Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 22 Jun 2026 20:52:59 +0000 Subject: [PATCH 5/6] fix(streaming): enforce terminal-chunk contract, stream timeout; fix docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Copilot review on PR #471: - StreamingTool.run_streaming / _drain now enforce 'exactly one terminal is_final chunk, and it must be last' — a chunk after the terminal chunk or a missing terminal raises ToolDefinitionError (was: silently kept the last). - _run_streaming_step_async bounds the whole stream with the tool's timeout_seconds via asyncio.wait_for, surfacing ToolTimeoutError (was: an async streaming step could hang indefinitely). Document that cache bypass and retry-skip for streaming steps are intentional (a partial stream cannot be replayed/memoised). - Correct the stream_flow_async docstring: a terminal flow_end carrying the partial IS emitted before FlowCancelledError is raised. Co-Authored-By: Claude Opus 4.8 Claude-Session: https://claude.ai/code/session_011U2ZKg2GMec7iAUDt2zezq --- chainweaver/executor.py | 40 ++++++++++++++++++++++---- chainweaver/tools.py | 38 +++++++++++++++++++++++-- tests/test_streaming_tools.py | 53 +++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 8 deletions(-) diff --git a/chainweaver/executor.py b/chainweaver/executor.py index 0ac66de..ffb882c 100644 --- a/chainweaver/executor.py +++ b/chainweaver/executor.py @@ -2799,11 +2799,22 @@ async def _run_streaming_step_async( :class:`StepRecord` built from the terminal chunk's assembled output. Tool failures route through the same async ``on_error`` machinery as a normal step (fail / skip / fallback). + + The tool's ``timeout_seconds`` bounds the *whole* stream via + :func:`asyncio.wait_for` (a stalled stream cannot hang the loop + indefinitely); a breach surfaces as + :class:`~chainweaver.exceptions.ToolTimeoutError`, matching the + non-streaming path. By design the step cache is bypassed and + ``step.retry`` is not applied to a streaming step: a partially consumed + stream (whose chunks have already been emitted downstream) cannot be + safely replayed or memoised. Those semantics are intentional, not an + oversight. """ tool_attempts[0] += 1 retry_errors: list[str] = [] - try: - final_output: dict[str, Any] | None = None + + async def _consume() -> dict[str, Any] | None: + collected: dict[str, Any] | None = None async for chunk in tool.run_streaming(inputs): self._fire_step_chunk( StepChunkContext( @@ -2815,8 +2826,22 @@ async def _run_streaming_step_async( ) ) if chunk.is_final: - final_output = chunk.data + collected = chunk.data + return collected + + try: + if tool.timeout_seconds is not None: + try: + final_output: dict[str, Any] | None = await asyncio.wait_for( + _consume(), timeout=tool.timeout_seconds + ) + except asyncio.TimeoutError as exc: + raise ToolTimeoutError(tool.name, tool.timeout_seconds) from exc + else: + final_output = await _consume() if final_output is None: + # ``run_streaming`` already enforces a terminal chunk; this guards + # the type and any future stream source that bypasses it. raise ToolDefinitionError( tool.name, "streaming tool produced no terminal (is_final=True) chunk." ) @@ -3254,9 +3279,12 @@ async def stream_flow_async( next step boundary by raising :class:`~chainweaver.exceptions.FlowCancelledError` from the iterator — its :attr:`~chainweaver.exceptions.FlowCancelledError.result` carries - the partial run. No ``flow_end`` is emitted on cancellation (execution - raised before the flow-end hook). If the consumer stops iterating - early, the backing :meth:`execute_flow_async` task is cancelled. + the partial run. A terminal ``flow_end`` event carrying that same + partial result is emitted *before* the error is raised (the + cancellation path builds the partial via ``_make_result``, which fires + the flow-end hook), so a consumer sees ``flow_end`` and then the raised + ``FlowCancelledError``. If the consumer stops iterating early, the + backing :meth:`execute_flow_async` task is cancelled. The async lane's feature support applies (issue #388): composed sub-flows, the step cache, and checkpoints work; conditional branching diff --git a/chainweaver/tools.py b/chainweaver/tools.py index 42fe9c0..04be3c3 100644 --- a/chainweaver/tools.py +++ b/chainweaver/tools.py @@ -674,10 +674,16 @@ class StreamingTool(Tool): through a non-streaming path (``run`` / ``run_async`` / the sync executor, or ``execute_flow_async`` without streaming consumption) it transparently drains the stream and returns the terminal chunk's assembled output, so it - behaves exactly like an ordinary tool. Only + produces the same final output as an ordinary tool. Only :meth:`~chainweaver.executor.FlowExecutor.stream_flow_async` surfaces the intermediate chunks as ``step_chunk`` events. + Streaming-step semantics (intentional, see + :meth:`~chainweaver.executor.FlowExecutor.execute_flow_async`): the tool's + ``timeout_seconds`` bounds the whole stream, but the executor's step cache + is bypassed and step-level ``retry`` is not applied to a streaming step — + a partially emitted stream cannot be safely memoised or replayed. + Example:: async def _tokens(inp: Query) -> AsyncIterator[ToolChunk]: @@ -717,14 +723,25 @@ async def _drain(validated_input: BaseModel) -> dict[str, Any]: # Non-streaming dispatch path: consume the stream and return the # terminal chunk's assembled output so the tool behaves like any # other tool under ``run`` / ``run_async`` / the sync executor. + # Enforce the streaming contract: exactly one terminal chunk, and + # it must be the last one emitted. final: dict[str, Any] | None = None + seen_final = False async for chunk in stream_fn(validated_input): + if seen_final: + raise ToolDefinitionError( + name, + "streaming tool yielded a chunk after its terminal " + "(is_final=True) chunk; the terminal chunk must be last.", + ) if chunk.is_final: + seen_final = True final = chunk.data - if final is None: + if not seen_final: raise ToolDefinitionError( name, "streaming tool produced no terminal (is_final=True) chunk." ) + assert final is not None return final super().__init__( @@ -748,14 +765,31 @@ async def run_streaming(self, raw_inputs: dict[str, Any]) -> AsyncIterator[ToolC (``is_final=True``) chunk's ``data`` is validated against the tool's ``output_schema`` (and size cap) before being yielded, so a streaming tool can never emit an invalid assembled output. + + Enforces the streaming contract: exactly one terminal chunk, and it + must be the last one emitted. A chunk after the terminal chunk, or no + terminal chunk at all, raises + :class:`~chainweaver.exceptions.ToolDefinitionError`. """ validated_input = self.input_schema.model_validate(raw_inputs) + seen_final = False async for chunk in self.stream_fn(validated_input): + if seen_final: + raise ToolDefinitionError( + self.name, + "streaming tool yielded a chunk after its terminal " + "(is_final=True) chunk; the terminal chunk must be last.", + ) if chunk.is_final: + seen_final = True validated = self._validate_output(chunk.data) yield ToolChunk(data=validated, is_final=True) else: yield chunk + if not seen_final: + raise ToolDefinitionError( + self.name, "streaming tool produced no terminal (is_final=True) chunk." + ) def _terminal_step(flow: Flow | DAGFlow) -> FlowStep: diff --git a/tests/test_streaming_tools.py b/tests/test_streaming_tools.py index 9b12620..d9f6c24 100644 --- a/tests/test_streaming_tools.py +++ b/tests/test_streaming_tools.py @@ -2,9 +2,11 @@ from __future__ import annotations +import asyncio from collections.abc import AsyncIterator from typing import Any +import pytest from pydantic import BaseModel from chainweaver import ( @@ -17,6 +19,7 @@ ToolChunk, ) from chainweaver.events import FlowEvent +from chainweaver.exceptions import ToolDefinitionError class _Query(BaseModel): @@ -201,3 +204,53 @@ async def _no_final(inp: _Query) -> AsyncIterator[ToolChunk]: ) result = await ex.execute_flow_async("bad", {"prompt": "hi"}) assert result.success is False + + +async def test_chunk_after_terminal_is_rejected() -> None: + async def _extra_after_final(inp: _Query) -> AsyncIterator[ToolChunk]: + yield ToolChunk(data={"text": "done"}, is_final=True) + yield ToolChunk(data={"delta": "oops"}) # illegal: chunk after terminal + + tool = StreamingTool( + name="g", + description="", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_extra_after_final, + ) + # The contract is enforced on both the streaming and the drained paths. + with pytest.raises(ToolDefinitionError, match="after its terminal"): + async for _ in tool.run_streaming({"prompt": "hi"}): + pass + with pytest.raises(ToolDefinitionError, match="after its terminal"): + await tool.run_async({"prompt": "hi"}) + + +async def test_streaming_tool_timeout_is_enforced() -> None: + async def _slow_stream(inp: _Query) -> AsyncIterator[ToolChunk]: + await asyncio.sleep(0.3) + yield ToolChunk(data={"text": "late"}, is_final=True) + + registry = FlowRegistry() + registry.register_flow( + Flow( + name="slow", + version="1.0.0", + description="", + steps=[FlowStep(tool_name="g", input_mapping={"prompt": "prompt"})], + ) + ) + ex = FlowExecutor(registry=registry) + ex.register_tool( + StreamingTool( + name="g", + description="", + input_schema=_Query, + output_schema=_Completion, + stream_fn=_slow_stream, + timeout_seconds=0.05, + ) + ) + result = await ex.execute_flow_async("slow", {"prompt": "hi"}) + assert result.success is False + assert result.execution_log[0].error_type == "ToolTimeoutError" From 4b575748fcd68adeca43cabd49e522a2444bcddc Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 24 Jun 2026 12:14:20 +0000 Subject: [PATCH 6/6] docs(events): correct stale stream cancellation docstrings Update the FlowEvent module docstring's Cancellation section to describe the implemented stream_flow_async (#389) and its FlowCancelledError semantics (terminal flow_end emitted before the error), replacing the never-shipped issue-#80 / asyncio.CancelledError plan and noting that sync stream_flow now honors cancel_token/deadline at step boundaries. Add the missing "step_chunk" variant to the FlowEvent.kind attribute docstring so it matches FlowEventKind and the variant table. Addresses two audit findings from the PR review cycle; docstring-only, no behavior change. Claude-Session: https://claude.ai/code/session_01NeMpgsSgvnq4MdnpyLuwcA --- chainweaver/events.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/chainweaver/events.py b/chainweaver/events.py index 5920ff2..7418f54 100644 --- a/chainweaver/events.py +++ b/chainweaver/events.py @@ -29,11 +29,18 @@ ------------ The sync :meth:`~chainweaver.executor.FlowExecutor.stream_flow` -generator does **not** cancel in-flight execution when the consumer -stops iterating: a background worker thread drives the flow to -completion, then exits. Document this loudly in any UI you build. -The async variant (gated on issue #80) is expected to support -:class:`asyncio.CancelledError`-driven cancellation cleanly. +generator drives the flow on a background worker thread. A +``cancel_token`` / ``deadline`` (issue #389) is honored at step +boundaries, but the worker still finishes the in-flight tool before +the stream ends — abandoning the generator does not abort a running +step. Document this loudly in any UI you build. + +The async-native :meth:`~chainweaver.executor.FlowExecutor.stream_flow_async` +(issue #389) cancels cooperatively: a ``cancel_token`` or ``deadline`` +ends the stream at the next step boundary by raising +:class:`~chainweaver.exceptions.FlowCancelledError`, whose ``result`` +carries the partial run. A terminal ``flow_end`` event is emitted +before the error is raised. """ from __future__ import annotations @@ -79,7 +86,7 @@ class FlowEvent(BaseModel): Attributes: kind: One of ``"flow_start"`` / ``"step_start"`` / - ``"step_end"`` / ``"flow_end"``. + ``"step_chunk"`` / ``"step_end"`` / ``"flow_end"``. flow_name: Name of the flow being executed. trace_id: UUID4 hex string correlating every event in this stream with logs and middleware contexts.