diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index 7c41b5da7..dafe598b7 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -143,6 +143,107 @@ await g.ainvoke({...}, context=Context(user_id="alice")) Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary. +## Streaming + +When `streaming_topic` is set on `LangGraphPlugin`, calls to `langgraph.config.get_stream_writer()` inside a node publish to the named topic on the workflow's [`WorkflowStream`](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Activity-side nodes publish via `WorkflowStreamClient` (a signal carrying batched items, controlled by `streaming_batch_interval`); workflow-side nodes publish synchronously to the in-workflow stream (no signal). External subscribers consume the stream with `WorkflowStreamClient.create(...).topic(...).subscribe(...)`. + +The workflow **must** construct `WorkflowStream()` in its `@workflow.init` (i.e. `__init__`) + +```python +from datetime import timedelta +from typing import Any + +from langgraph.config import get_stream_writer +from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph import LangGraphPlugin, graph +from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +async def token_node(state: State) -> dict[str, str]: + writer = get_stream_writer() + for token in ["hello", " ", "world"]: + writer({"token": token}) + writer({"done": True}) + return {"value": "hello world"} + + +@workflow.defn +class StreamingWorkflow: + def __init__(self) -> None: + # Required when streaming_topic is set on the plugin. + _ = WorkflowStream() + self.app = graph("streaming").compile() + + @workflow.run + async def run(self) -> str: + result = await self.app.ainvoke({"value": ""}) + return result["value"] + + +async def main(client: Client) -> None: + g = StateGraph(State) + g.add_node("token_node", token_node, metadata={"execute_in": "activity"}) + g.add_edge(START, "token_node") + + async with Worker( + client, + task_queue="streaming-tq", + workflows=[StreamingWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"streaming": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + streaming_topic="tokens", + ) + ], + ): + handle = await client.start_workflow( + StreamingWorkflow.run, id="streaming-wf", task_queue="streaming-tq" + ) + + ws_client = WorkflowStreamClient.create(client, handle.id) + async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0): + print(item.data) + if item.data.get("done"): + break + + print(await handle.result()) +``` + +### What's covered, and what isn't + +`streaming_topic` wires up exactly **one** LangGraph stream mode: `stream_mode="custom"`, i.e. values written through `get_stream_writer()`. The other modes — `"messages"`, `"values"`, `"updates"`, `"debug"` — are **not** captured by `streaming_topic`. They aren't produced by node-side writers; LangGraph's orchestrator emits them as it walks the graph. The documented pattern is to **bridge `astream()` in the workflow** and republish each yielded chunk to a `WorkflowStream` topic yourself: + +```python +@workflow.defn +class AstreamBridge: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.app = graph("g").compile() + + @workflow.run + async def run(self) -> None: + topic = self.stream.topic("astream") + async for chunk in self.app.astream({...}, stream_mode="messages"): + topic.publish(chunk) + topic.publish({"done": True}) +``` + +### Retry semantics + +Streaming has **at-least-once** delivery per activity attempt. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes — earlier publishes from the failed attempt are not rolled back. Subscribers should be ready to see duplicates and recover idempotently (e.g. dedupe on a sequence id you include in each chunk, or treat the stream as advisory and rely on the workflow's final result for state). + ## Tracing We recommend the [Temporal LangSmith Plugin](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/langsmith) to trace your LangGraph Workflows and Activities. diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index c12d459a6..e9aaf5605 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -19,7 +19,7 @@ __all__ = [ "LangGraphPlugin", - "entrypoint", "cache", + "entrypoint", "graph", ] diff --git a/temporalio/contrib/langgraph/_activity.py b/temporalio/contrib/langgraph/_activity.py index f1d66a200..5891eb81b 100644 --- a/temporalio/contrib/langgraph/_activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -1,7 +1,9 @@ """Activity wrappers for executing LangGraph nodes and tasks.""" +import asyncio from collections.abc import Awaitable from dataclasses import dataclass +from datetime import timedelta from inspect import iscoroutinefunction, signature from typing import Any, Callable @@ -19,6 +21,7 @@ cache_lookup, cache_put, ) +from temporalio.contrib.workflow_streams import WorkflowStreamClient # Per-run dedupe so we only warn once when a user passes a Store via # graph.compile(store=...) / @entrypoint(store=...). Cleared by @@ -51,28 +54,53 @@ class ActivityOutput: def wrap_activity( func: Callable, + *, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" - # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks - # don't and instead reach for Runtime via get_runtime(). We re-inject the - # reconstructed Runtime only when the user function asks. - accepts_runtime = "runtime" in signature(func).parameters async def wrapper(input: ActivityInput) -> ActivityOutput: - runtime = set_langgraph_config(input.langgraph_config) - kwargs = dict(input.kwargs) - if accepts_runtime: - kwargs["runtime"] = runtime - try: - if iscoroutinefunction(func): - result = await func(*input.args, **kwargs) - else: - result = func(*input.args, **kwargs) - if isinstance(result, Command): - return ActivityOutput(langgraph_command=result) - return ActivityOutput(result=result) - except GraphInterrupt as e: - return ActivityOutput(langgraph_interrupts=e.args[0]) + async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput: + # Sync funcs run on a thread (so the loop keeps flushing the + # stream client mid-execution); marshal writer calls back to + # the loop thread because the client's flush event is an + # asyncio.Event and isn't safe to set off-thread. + effective_writer = stream_writer + if not iscoroutinefunction(func) and stream_writer is not None: + loop = asyncio.get_running_loop() + inner_writer = stream_writer + + def thread_safe_writer(value: Any) -> None: + loop.call_soon_threadsafe(inner_writer, value) + + effective_writer = thread_safe_writer + + runtime = set_langgraph_config( + input.langgraph_config, stream_writer=effective_writer + ) + kwargs = dict(input.kwargs) + if "runtime" in signature(func).parameters: + kwargs["runtime"] = runtime + + try: + if iscoroutinefunction(func): + result = await func(*input.args, **kwargs) + else: + result = await asyncio.to_thread(func, *input.args, **kwargs) + if isinstance(result, Command): + return ActivityOutput(langgraph_command=result) + return ActivityOutput(result=result) + except GraphInterrupt as e: + return ActivityOutput(langgraph_interrupts=e.args[0]) + + if streaming_topic is None: + return await run(stream_writer=None) + async with WorkflowStreamClient.from_within_activity( + batch_interval=streaming_batch_interval, + ) as client: + topic = client.topic(streaming_topic) + return await run(stream_writer=topic.publish) return wrapper diff --git a/temporalio/contrib/langgraph/_interceptor.py b/temporalio/contrib/langgraph/_interceptor.py index fd583c052..f68d9d45d 100644 --- a/temporalio/contrib/langgraph/_interceptor.py +++ b/temporalio/contrib/langgraph/_interceptor.py @@ -11,6 +11,7 @@ from temporalio import workflow from temporalio.contrib.langgraph._activity import clear_store_warning +from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL from temporalio.worker import ( ExecuteWorkflowInput, Interceptor, @@ -30,10 +31,12 @@ def __init__( self, graphs: dict[str, StateGraph[Any, Any, Any, Any]], entrypoints: dict[str, Pregel[Any, Any, Any, Any]], + streaming_topic: str | None = None, ) -> None: """Initialize with the graphs and entrypoints to scope to each workflow run.""" self._graphs = graphs self._entrypoints = entrypoints + self._streaming_topic = streaming_topic def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput @@ -41,6 +44,7 @@ def workflow_interceptor_class( """Return the inbound interceptor class used to scope graphs per run.""" graphs = self._graphs entrypoints = self._entrypoints + streaming_topic = self._streaming_topic class Inbound(WorkflowInboundInterceptor): def init(self, outbound: WorkflowOutboundInterceptor) -> None: @@ -50,6 +54,18 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None: super().init(outbound) async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + if ( + streaming_topic is not None + and workflow.get_signal_handler(_PUBLISH_SIGNAL) is None + ): + raise RuntimeError( + f"LangGraphPlugin was configured with " + f"streaming_topic={streaming_topic!r}, but workflow " + f"{workflow.info().workflow_type!r} did not register a " + f"WorkflowStream. Construct WorkflowStream() in the " + f"workflow's @workflow.init (i.e. __init__) method so " + f"streaming activities can publish to it." + ) try: return await self.next.execute_workflow(input) finally: diff --git a/temporalio/contrib/langgraph/_langgraph_config.py b/temporalio/contrib/langgraph/_langgraph_config.py index 4cc529477..90c6c810d 100644 --- a/temporalio/contrib/langgraph/_langgraph_config.py +++ b/temporalio/contrib/langgraph/_langgraph_config.py @@ -3,7 +3,7 @@ # pyright: reportMissingTypeStubs=false import dataclasses -from typing import Any +from typing import Any, Callable from langchain_core.runnables.config import var_child_runnable_config from langgraph._internal._constants import ( @@ -93,7 +93,11 @@ def get_langgraph_config() -> dict[str, Any]: } -def set_langgraph_config(config: dict[str, Any]) -> Runtime: +def set_langgraph_config( + config: dict[str, Any], + *, + stream_writer: Callable[[Any], None] | None = None, +) -> Runtime: """Restore a LangGraph runnable config from a serialized dict. Returns the reconstructed Runtime so callers can re-inject it into the @@ -112,7 +116,7 @@ def get_null_resume(consume: bool = False) -> Any: execution_info_dict = config.get("execution_info") runtime = Runtime( context=config.get("context"), - stream_writer=lambda _: None, + stream_writer=stream_writer or (lambda _: None), previous=config.get("previous"), execution_info=( ExecutionInfo(**execution_info_dict) if execution_info_dict else None diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index a624f62a7..a1320d1a8 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -8,6 +8,7 @@ import sys import warnings from dataclasses import replace +from datetime import timedelta from typing import Any, Callable from langgraph._internal._runnable import RunnableCallable @@ -26,6 +27,7 @@ set_task_cache, task_id, ) +from temporalio.contrib.langgraph._workflow import wrap_workflow from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -46,6 +48,49 @@ class LangGraphPlugin(SimplePlugin): and tasks as Temporal Activities, giving your AI agent workflows durable execution, automatic retries, and timeouts. It supports both the LangGraph Graph API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). + + Args: + graphs: Graph API graphs to make available to workflows, keyed by name. + Workflows retrieve them with :func:`graph` and call + ``.compile()`` to get a runnable. Each node's ``metadata`` must + include ``execute_in`` (``"activity"`` or ``"workflow"``) and + may include any kwarg accepted by + :func:`workflow.execute_activity` (e.g. ``start_to_close_timeout``, + ``retry_policy``). + entrypoints: Functional API entrypoints to make available to + workflows, keyed by name. Workflows retrieve them with + :func:`entrypoint`. + tasks: Functional API ``@task`` functions to wrap as Temporal + Activities. + activity_options: Per-task activity options for the Functional + API, keyed by task function name. Each entry must include + ``execute_in`` and may include any + :func:`workflow.execute_activity` kwarg. Used because LangGraph's + Functional API has no per-task ``metadata`` channel. + default_activity_options: Activity options applied to every + activity-bound node and task, overridable per-node (Graph API + ``metadata``) or per-task (``activity_options[name]``). + streaming_topic: When set, ``langgraph.config.get_stream_writer()`` + inside a node publishes to this topic on the workflow's + :class:`WorkflowStream`. The workflow must construct + ``WorkflowStream()`` in its ``@workflow.init`` (the plugin's + interceptor verifies this on workflow start). Nodes with + ``execute_in='activity'`` publish through + :class:`WorkflowStreamClient` (signal); nodes with + ``execute_in='workflow'`` publish synchronously to the + in-workflow stream (no signal). + streaming_batch_interval: How often the activity-side stream + client flushes buffered publishes into a single + ``__temporal_workflow_stream_publish`` signal. Has no effect + on workflow-side nodes (their publishes are synchronous + in-memory log appends). Lower values reduce streaming + latency at the cost of more signals (more workflow history + events); higher values amortize signal cost but make + chunks arrive in larger bursts. Default 100ms suits + interactive token streaming; raise to 250–1000ms for + non-interactive aggregation, lower toward 10–50ms only if + you've measured the latency need and accept the history + cost. """ def __init__( @@ -58,8 +103,15 @@ def __init__( # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict[str, Any]] | None = None, default_activity_options: dict[str, Any] | None = None, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), ): - """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" + """Initialize the LangGraph plugin with graphs, entrypoints, and tasks. + + .. warning:: + Streaming support is experimental and may change in + future versions. + """ if sys.version_info < (3, 11): warnings.warn( # type: ignore[reportUnreachable] "LangGraphPlugin requires Python >= 3.11 for full async support. " @@ -79,6 +131,8 @@ def __init__( ) self.activities: list = [] + self._streaming_topic = streaming_topic + self._streaming_batch_interval = streaming_batch_interval # Graph API: Wrap graph nodes as Temporal Activities. if graphs: @@ -95,7 +149,7 @@ def __init__( runnable = node.runnable if not isinstance(runnable, RunnableCallable): raise ValueError(f"Node {node_name} must be a RunnableCallable") - user_func = runnable.afunc or runnable.func + user_func = runnable.func or runnable.afunc if user_func is None: raise ValueError(f"Node {node_name} must have a function") # Keep 'config' (for metadata/tags) and 'runtime' (for @@ -183,7 +237,11 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "langchain.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, - interceptors=[LangGraphInterceptor(graphs or {}, entrypoints or {})], + interceptors=[ + LangGraphInterceptor( + graphs or {}, entrypoints or {}, streaming_topic=streaming_topic + ) + ], ) def execute( @@ -197,11 +255,16 @@ def execute( execute_in = opts.pop("execute_in") if execute_in == "activity": - a = activity.defn(name=activity_name)(wrap_activity(func)) + wrapped = wrap_activity( + func, + streaming_topic=self._streaming_topic, + streaming_batch_interval=self._streaming_batch_interval, + ) + a = activity.defn(name=activity_name)(wrapped) self.activities.append(a) return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": - return func + return wrap_workflow(func, streaming_topic=self._streaming_topic) else: raise ValueError(f"Invalid execute_in value: {execute_in}") diff --git a/temporalio/contrib/langgraph/_workflow.py b/temporalio/contrib/langgraph/_workflow.py new file mode 100644 index 000000000..67bfd4f68 --- /dev/null +++ b/temporalio/contrib/langgraph/_workflow.py @@ -0,0 +1,62 @@ +"""Workflow-side wrappers for executing LangGraph nodes inline in a workflow.""" + +# pyright: reportMissingTypeStubs=false + +from __future__ import annotations + +import dataclasses +from collections.abc import Awaitable +from inspect import iscoroutinefunction +from typing import Any, Callable + +from langchain_core.runnables.config import var_child_runnable_config +from langgraph._internal._constants import CONFIG_KEY_RUNTIME + +from temporalio import workflow +from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL + + +def wrap_workflow( + func: Callable[..., Any], + *, + streaming_topic: str | None = None, +) -> Callable[..., Awaitable[Any]]: + """Wrap a function as a workflow-side LangGraph node. + + Mirrors :func:`wrap_activity`: the outer wrapper resolves a stream + writer and passes it to an inner ``run`` that invokes the user + function with the writer installed. Workflow-side nodes publish + synchronously to the in-workflow ``WorkflowStream`` (no signal + round-trip); activity-side nodes go through ``WorkflowStreamClient``. + """ + + async def wrapper(*args: Any, **kwargs: Any) -> Any: + async def run(stream_writer: Callable[[Any], None] | None) -> Any: + token = None + if stream_writer is not None: + config = var_child_runnable_config.get() or {} + configurable = dict(config.get("configurable") or {}) + runtime = configurable.get(CONFIG_KEY_RUNTIME) + if runtime is not None: + configurable[CONFIG_KEY_RUNTIME] = dataclasses.replace( + runtime, stream_writer=stream_writer + ) + token = var_child_runnable_config.set( + {**config, "configurable": configurable} + ) + try: + if iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) + finally: + if token is not None: + var_child_runnable_config.reset(token) + + if streaming_topic is None: + return await run(stream_writer=None) + publish_handler = workflow.get_signal_handler(_PUBLISH_SIGNAL) + stream = getattr(publish_handler, "__self__") + topic = stream.topic(streaming_topic) + return await run(stream_writer=topic.publish) + + return wrapper diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index f47feffee..ff00b8906 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -2,12 +2,17 @@ from typing import Any from uuid import uuid4 +import pytest +from langgraph.config import ( + get_stream_writer, # pyright: ignore[reportMissingTypeStubs] +) from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client from temporalio.contrib.langgraph import LangGraphPlugin, graph +from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Worker @@ -15,54 +20,164 @@ class State(TypedDict): value: str -async def node_a(state: State) -> dict[str, str]: - return {"value": state["value"] + "a"} +async def async_token_node(state: State) -> dict[str, str]: + tokens = ["a", "b", "c"] + writer = get_stream_writer() + for token in tokens: + writer({"token": token}) + writer({"done": True}) + return {"value": state["value"] + "".join(tokens)} -async def node_b(state: State) -> dict[str, str]: - return {"value": state["value"] + "b"} +def sync_token_node(state: State) -> dict[str, str]: + tokens = ["a", "b", "c"] + writer = get_stream_writer() + for token in tokens: + writer({"token": token}) + writer({"done": True}) + return {"value": state["value"] + "".join(tokens)} + + +@workflow.defn +class StreamingWorkflowStreamsWorkflow: + def __init__(self) -> None: + _ = WorkflowStream() + self.app = graph("streaming-ws").compile() + + @workflow.run + async def run(self, input: str) -> str: + result = await self.app.ainvoke({"value": input}) + return result["value"] + + +@pytest.mark.parametrize("execute_in", ["activity", "workflow"]) +@pytest.mark.parametrize( + "node", [async_token_node, sync_token_node], ids=["async", "sync"] +) +async def test_streaming_via_workflow_streams( + client: Client, execute_in: str, node: Any +): + g = StateGraph(State) + g.add_node("token_node", node, metadata={"execute_in": execute_in}) + g.add_edge(START, "token_node") + + task_queue = f"streaming-ws-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[StreamingWorkflowStreamsWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"streaming-ws": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + streaming_topic="tokens", + ) + ], + ): + handle = await client.start_workflow( + StreamingWorkflowStreamsWorkflow.run, + "", + id=f"test-streaming-ws-{uuid4()}", + task_queue=task_queue, + ) + + ws_client = WorkflowStreamClient.create(client, handle.id) + chunks: list[dict[str, Any]] = [] + async for item in ws_client.topic("tokens", type=dict).subscribe( + from_offset=0, + poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break + + result = await handle.result() + + assert result == "abc" + assert chunks == [ + {"token": "a"}, + {"token": "b"}, + {"token": "c"}, + {"done": True}, + ] + + +# --------------------------------------------------------------------------- +# Workflow-side publish: iterate astream() in the workflow and forward each +# chunk via self.stream.topic("astream").publish(...) so external subscribers +# see node-level progress alongside any activity-emitted tokens. +# --------------------------------------------------------------------------- @workflow.defn -class StreamingWorkflow: +class AstreamPublishWorkflow: def __init__(self) -> None: - self.app = graph("streaming").compile() + self.stream = WorkflowStream() + self.app = graph("astream-publish").compile() @workflow.run - async def run(self, input: str) -> Any: - chunks = [] + async def run(self, input: str) -> str: + topic = self.stream.topic("astream") async for chunk in self.app.astream({"value": input}): - chunks.append(chunk) - return chunks + topic.publish(chunk) + topic.publish({"done": True}) + return "done" + + +async def node_a(state: State) -> dict[str, str]: + return {"value": state["value"] + "a"} + + +async def node_b(state: State) -> dict[str, str]: + return {"value": state["value"] + "b"} -async def test_streaming(client: Client): +async def test_workflow_publishes_astream_chunks(client: Client): g = StateGraph(State) g.add_node("node_a", node_a, metadata={"execute_in": "activity"}) g.add_node("node_b", node_b, metadata={"execute_in": "activity"}) g.add_edge(START, "node_a") g.add_edge("node_a", "node_b") - task_queue = f"streaming-{uuid4()}" + task_queue = f"astream-publish-{uuid4()}" async with Worker( client, task_queue=task_queue, - workflows=[StreamingWorkflow], + workflows=[AstreamPublishWorkflow], plugins=[ LangGraphPlugin( - graphs={"streaming": g}, + graphs={"astream-publish": g}, default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, ) ], ): - chunks = await client.execute_workflow( - StreamingWorkflow.run, + handle = await client.start_workflow( + AstreamPublishWorkflow.run, "", - id=f"test-streaming-{uuid4()}", + id=f"test-astream-publish-{uuid4()}", task_queue=task_queue, ) - assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}] + ws_client = WorkflowStreamClient.create(client, handle.id) + chunks: list[dict[str, Any]] = [] + async for item in ws_client.topic("astream", type=dict).subscribe( + from_offset=0, + poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break + + await handle.result() + + assert chunks == [ + {"node_a": {"value": "a"}}, + {"node_b": {"value": "ab"}}, + {"done": True}, + ]