From 360ef92b796355f0e93b396d9169b265751153ac Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 09:52:57 -0700 Subject: [PATCH 01/15] first pass at langgraph streaming --- temporalio/contrib/langgraph/__init__.py | 4 +- temporalio/contrib/langgraph/_activity.py | 54 +++++-- .../contrib/langgraph/_langgraph_config.py | 14 +- .../langgraph/test_functional_runtime.py | 79 +++++++++ tests/contrib/langgraph/test_runtime.py | 66 ++++++++ tests/contrib/langgraph/test_streaming.py | 151 +++++++++++++++++- 6 files changed, 349 insertions(+), 19 deletions(-) create mode 100644 tests/contrib/langgraph/test_functional_runtime.py create mode 100644 tests/contrib/langgraph/test_runtime.py diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index c12d459a6..f7f571bcc 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -10,6 +10,7 @@ API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). """ +from temporalio.contrib.langgraph._activity import STREAM_TOPIC from temporalio.contrib.langgraph._plugin import ( LangGraphPlugin, cache, @@ -19,7 +20,8 @@ __all__ = [ "LangGraphPlugin", - "entrypoint", + "STREAM_TOPIC", "cache", + "entrypoint", "graph", ] diff --git a/temporalio/contrib/langgraph/_activity.py b/temporalio/contrib/langgraph/_activity.py index f1d66a200..8b86e7bf1 100644 --- a/temporalio/contrib/langgraph/_activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -19,6 +19,16 @@ cache_lookup, cache_put, ) +from temporalio.contrib.workflow_streams import WorkflowStreamClient + +STREAM_TOPIC = "langgraph_stream" +"""Topic that LangGraph node stream_writer chunks are published to. + +Each chunk is encoded by the configured payload converter and delivered +to the parent workflow's :class:`WorkflowStream`. Subscribers receive +already-decoded values via ``WorkflowStreamClient.subscribe`` — +``item.data`` is the chunk, no manual decoding required. +""" # Per-run dedupe so we only warn once when a user passes a Store via # graph.compile(store=...) / @entrypoint(store=...). Cleared by @@ -59,20 +69,36 @@ def wrap_activity( accepts_runtime = "runtime" in signature(func).parameters async def wrapper(input: ActivityInput) -> ActivityOutput: - runtime = set_langgraph_config(input.langgraph_config) - kwargs = dict(input.kwargs) - if accepts_runtime: - kwargs["runtime"] = runtime - try: - if iscoroutinefunction(func): - result = await func(*input.args, **kwargs) - else: - result = func(*input.args, **kwargs) - if isinstance(result, Command): - return ActivityOutput(langgraph_command=result) - return ActivityOutput(result=result) - except GraphInterrupt as e: - return ActivityOutput(langgraph_interrupts=e.args[0]) + # Back get_stream_writer() with a WorkflowStreamClient targeting the + # owning workflow. Chunks emitted inside the node are signaled back + # to the workflow's WorkflowStream. If the node never calls + # writer(...), the buffer stays empty and the final flush is a + # no-op — no signals are sent. + client = WorkflowStreamClient.from_within_activity() + + def stream_writer(chunk: Any) -> None: + # force_flush=True wakes the flusher to send immediately instead + # of waiting for the batch_interval timer; rapid writer calls + # still coalesce into a single signal while in-flight. + client.topic(STREAM_TOPIC).publish(chunk, force_flush=True) + + async with client: + runtime = set_langgraph_config( + input.langgraph_config, stream_writer=stream_writer + ) + kwargs = dict(input.kwargs) + if accepts_runtime: + kwargs["runtime"] = runtime + try: + if iscoroutinefunction(func): + result = await func(*input.args, **kwargs) + else: + result = func(*input.args, **kwargs) + if isinstance(result, Command): + return ActivityOutput(langgraph_command=result) + return ActivityOutput(result=result) + except GraphInterrupt as e: + return ActivityOutput(langgraph_interrupts=e.args[0]) return wrapper diff --git a/temporalio/contrib/langgraph/_langgraph_config.py b/temporalio/contrib/langgraph/_langgraph_config.py index 4cc529477..cad9ce5ee 100644 --- a/temporalio/contrib/langgraph/_langgraph_config.py +++ b/temporalio/contrib/langgraph/_langgraph_config.py @@ -3,7 +3,7 @@ # pyright: reportMissingTypeStubs=false import dataclasses -from typing import Any +from typing import Any, Callable from langchain_core.runnables.config import var_child_runnable_config from langgraph._internal._constants import ( @@ -93,11 +93,19 @@ def get_langgraph_config() -> dict[str, Any]: } -def set_langgraph_config(config: dict[str, Any]) -> Runtime: +def set_langgraph_config( + config: dict[str, Any], + *, + stream_writer: Callable[[Any], None] | None = None, +) -> Runtime: """Restore a LangGraph runnable config from a serialized dict. Returns the reconstructed Runtime so callers can re-inject it into the user function's kwargs without needing to know the configurable layout. + + If ``stream_writer`` is provided, it replaces the default no-op writer + in the reconstructed Runtime, so ``get_stream_writer()`` inside a node + delivers chunks through the caller's sink. """ configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) or {} @@ -112,7 +120,7 @@ def get_null_resume(consume: bool = False) -> Any: execution_info_dict = config.get("execution_info") runtime = Runtime( context=config.get("context"), - stream_writer=lambda _: None, + stream_writer=stream_writer or (lambda _: None), previous=config.get("previous"), execution_info=( ExecutionInfo(**execution_info_dict) if execution_info_dict else None diff --git a/tests/contrib/langgraph/test_functional_runtime.py b/tests/contrib/langgraph/test_functional_runtime.py new file mode 100644 index 000000000..8619bb10d --- /dev/null +++ b/tests/contrib/langgraph/test_functional_runtime.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import sys +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +import pytest + +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="LangGraph Functional API requires Python >= 3.11 for async context propagation", +) + +from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] + entrypoint as lg_entrypoint, +) +from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] +from langgraph.runtime import get_runtime +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph import LangGraphPlugin, entrypoint +from temporalio.worker import Worker + + +class Context(TypedDict): + user_id: str + + +@task +async def read_user_id() -> str: + runtime = get_runtime(Context) + return runtime.context["user_id"] + + +@lg_entrypoint() +async def read_user_id_entrypoint(_: str) -> dict[str, str]: + user_id = await read_user_id() + return {"user_id": user_id} + + +@workflow.defn +class FunctionalRuntimeContextWorkflow: + def __init__(self) -> None: + self.app = entrypoint("read_user_id") + + @workflow.run + async def run(self, user_id: str) -> Any: + return await self.app.ainvoke("", context=Context(user_id=user_id)) + + +async def test_functional_runtime_context(client: Client) -> None: + task_queue = f"functional-runtime-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[FunctionalRuntimeContextWorkflow], + plugins=[ + LangGraphPlugin( + entrypoints={"read_user_id": read_user_id_entrypoint}, + tasks=[read_user_id], + activity_options={"read_user_id": {"execute_in": "activity"}}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + FunctionalRuntimeContextWorkflow.run, + "user-123", + id=f"test-functional-runtime-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"user_id": "user-123"} diff --git a/tests/contrib/langgraph/test_runtime.py b/tests/contrib/langgraph/test_runtime.py new file mode 100644 index 000000000..90104b1b2 --- /dev/null +++ b/tests/contrib/langgraph/test_runtime.py @@ -0,0 +1,66 @@ +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] +from langgraph.runtime import Runtime +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph import LangGraphPlugin, graph +from temporalio.worker import Worker + + +class Context(TypedDict): + user_id: str + + +class State(TypedDict): + user_id: str + + +async def read_user_id(state: State, runtime: Runtime[Context]) -> dict[str, str]: + return {"user_id": runtime.context["user_id"]} + + +@workflow.defn +class RuntimeContextWorkflow: + def __init__(self) -> None: + self.app = graph("my-graph").compile() + + @workflow.run + async def run(self, user_id: str) -> Any: + return await self.app.ainvoke( + {"user_id": ""}, context=Context(user_id=user_id) + ) + + +async def test_runtime_context(client: Client): + g = StateGraph(State, context_schema=Context) + g.add_node("read_user_id", read_user_id, metadata={"execute_in": "activity"}) + g.add_edge(START, "read_user_id") + + task_queue = f"runtime-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[RuntimeContextWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"my-graph": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + result = await client.execute_workflow( + RuntimeContextWorkflow.run, + "user-123", + id=f"test-runtime-{uuid4()}", + task_queue=task_queue, + ) + + assert result == {"user_id": "user-123"} diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index f47feffee..f34b28f2e 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -1,13 +1,16 @@ +import asyncio from datetime import timedelta from typing import Any from uuid import uuid4 +from langgraph.config import get_stream_writer # pyright: ignore[reportMissingTypeStubs] from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph import LangGraphPlugin, graph +from temporalio.contrib.langgraph import STREAM_TOPIC, LangGraphPlugin, graph +from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Worker @@ -66,3 +69,149 @@ async def test_streaming(client: Client): ) assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}] + + +# --------------------------------------------------------------------------- +# Streaming via WorkflowStream: stream_writer inside an activity-wrapped node +# publishes back to the owning workflow, an external client subscribes. +# --------------------------------------------------------------------------- + +TOKENS = ["alpha", "beta", "gamma"] + + +async def token_node(state: State) -> dict[str, str]: + writer = get_stream_writer() + for token in TOKENS: + writer({"token": token}) + writer({"done": True}) + return {"value": "".join(TOKENS)} + + +@workflow.defn +class StreamingWorkflowStreamsWorkflow: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.app = graph("streaming-ws").compile() + + @workflow.run + async def run(self, input: str) -> str: + result = await self.app.ainvoke({"value": input}) + return result["value"] + + +async def test_streaming_via_workflow_streams(client: Client): + g = StateGraph(State) + g.add_node("token_node", token_node, metadata={"execute_in": "activity"}) + g.add_edge(START, "token_node") + + task_queue = f"streaming-ws-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[StreamingWorkflowStreamsWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"streaming-ws": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + handle = await client.start_workflow( + StreamingWorkflowStreamsWorkflow.run, + "", + id=f"test-streaming-ws-{uuid4()}", + task_queue=task_queue, + ) + + ws_client = WorkflowStreamClient.create(client, handle.id) + chunks: list[dict[str, Any]] = [] + async with asyncio.timeout(15.0): + async for item in ws_client.topic(STREAM_TOPIC, type=dict).subscribe( + from_offset=0, poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break + + result = await handle.result() + + assert result == "alphabetagamma" + assert chunks == [ + {"token": "alpha"}, + {"token": "beta"}, + {"token": "gamma"}, + {"done": True}, + ] + + +# --------------------------------------------------------------------------- +# Workflow-side publish: iterate astream() in the workflow and forward each +# chunk via self.stream.topic("astream").publish(...) so external subscribers +# see node-level progress alongside any activity-emitted tokens. +# --------------------------------------------------------------------------- + + +@workflow.defn +class AstreamPublishWorkflow: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.app = graph("astream-publish").compile() + + @workflow.run + async def run(self, input: str) -> str: + topic = self.stream.topic("astream") + async for chunk in self.app.astream({"value": input}): + topic.publish(chunk) + topic.publish({"done": True}) + return "done" + + +async def test_workflow_publishes_astream_chunks(client: Client): + g = StateGraph(State) + g.add_node("node_a", node_a, metadata={"execute_in": "activity"}) + g.add_node("node_b", node_b, metadata={"execute_in": "activity"}) + g.add_edge(START, "node_a") + g.add_edge("node_a", "node_b") + + task_queue = f"astream-publish-{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[AstreamPublishWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"astream-publish": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + ) + ], + ): + handle = await client.start_workflow( + AstreamPublishWorkflow.run, + "", + id=f"test-astream-publish-{uuid4()}", + task_queue=task_queue, + ) + + ws_client = WorkflowStreamClient.create(client, handle.id) + chunks: list[dict[str, Any]] = [] + async with asyncio.timeout(15.0): + async for item in ws_client.topic("astream", type=dict).subscribe( + from_offset=0, poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break + + await handle.result() + + assert chunks == [ + {"node_a": {"value": "a"}}, + {"node_b": {"value": "ab"}}, + {"done": True}, + ] From 0b3f2bdce6815e1eaa2560a467ab7dcf1d16b53c Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 12:48:56 -0700 Subject: [PATCH 02/15] Trim obvious comments in langgraph activity/config --- temporalio/contrib/langgraph/_activity.py | 15 ++------------- temporalio/contrib/langgraph/_langgraph_config.py | 4 ---- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/temporalio/contrib/langgraph/_activity.py b/temporalio/contrib/langgraph/_activity.py index 8b86e7bf1..f7ee97c0e 100644 --- a/temporalio/contrib/langgraph/_activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -22,13 +22,7 @@ from temporalio.contrib.workflow_streams import WorkflowStreamClient STREAM_TOPIC = "langgraph_stream" -"""Topic that LangGraph node stream_writer chunks are published to. - -Each chunk is encoded by the configured payload converter and delivered -to the parent workflow's :class:`WorkflowStream`. Subscribers receive -already-decoded values via ``WorkflowStreamClient.subscribe`` — -``item.data`` is the chunk, no manual decoding required. -""" +"""Workflow stream topic that LangGraph stream_writer publishes to.""" # Per-run dedupe so we only warn once when a user passes a Store via # graph.compile(store=...) / @entrypoint(store=...). Cleared by @@ -71,15 +65,10 @@ def wrap_activity( async def wrapper(input: ActivityInput) -> ActivityOutput: # Back get_stream_writer() with a WorkflowStreamClient targeting the # owning workflow. Chunks emitted inside the node are signaled back - # to the workflow's WorkflowStream. If the node never calls - # writer(...), the buffer stays empty and the final flush is a - # no-op — no signals are sent. + # to the workflow's WorkflowStream. client = WorkflowStreamClient.from_within_activity() def stream_writer(chunk: Any) -> None: - # force_flush=True wakes the flusher to send immediately instead - # of waiting for the batch_interval timer; rapid writer calls - # still coalesce into a single signal while in-flight. client.topic(STREAM_TOPIC).publish(chunk, force_flush=True) async with client: diff --git a/temporalio/contrib/langgraph/_langgraph_config.py b/temporalio/contrib/langgraph/_langgraph_config.py index cad9ce5ee..90c6c810d 100644 --- a/temporalio/contrib/langgraph/_langgraph_config.py +++ b/temporalio/contrib/langgraph/_langgraph_config.py @@ -102,10 +102,6 @@ def set_langgraph_config( Returns the reconstructed Runtime so callers can re-inject it into the user function's kwargs without needing to know the configurable layout. - - If ``stream_writer`` is provided, it replaces the default no-op writer - in the reconstructed Runtime, so ``get_stream_writer()`` inside a node - delivers chunks through the caller's sink. """ configurable = config.get("configurable") or {} scratchpad = configurable.get(CONFIG_KEY_SCRATCHPAD) or {} From 0d1a53646b238b95a160eb4113d88d4175e1a4ab Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 12:49:22 -0700 Subject: [PATCH 03/15] Remove unrelated runtime tests --- .../langgraph/test_functional_runtime.py | 79 ------------------- tests/contrib/langgraph/test_runtime.py | 66 ---------------- 2 files changed, 145 deletions(-) delete mode 100644 tests/contrib/langgraph/test_functional_runtime.py delete mode 100644 tests/contrib/langgraph/test_runtime.py diff --git a/tests/contrib/langgraph/test_functional_runtime.py b/tests/contrib/langgraph/test_functional_runtime.py deleted file mode 100644 index 8619bb10d..000000000 --- a/tests/contrib/langgraph/test_functional_runtime.py +++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations - -import sys -from datetime import timedelta -from typing import Any -from uuid import uuid4 - -import pytest - -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 11), - reason="LangGraph Functional API requires Python >= 3.11 for async context propagation", -) - -from langgraph.func import ( # pyright: ignore[reportMissingTypeStubs] - entrypoint as lg_entrypoint, -) -from langgraph.func import task # pyright: ignore[reportMissingTypeStubs] -from langgraph.runtime import get_runtime -from typing_extensions import TypedDict - -from temporalio import workflow -from temporalio.client import Client -from temporalio.contrib.langgraph import LangGraphPlugin, entrypoint -from temporalio.worker import Worker - - -class Context(TypedDict): - user_id: str - - -@task -async def read_user_id() -> str: - runtime = get_runtime(Context) - return runtime.context["user_id"] - - -@lg_entrypoint() -async def read_user_id_entrypoint(_: str) -> dict[str, str]: - user_id = await read_user_id() - return {"user_id": user_id} - - -@workflow.defn -class FunctionalRuntimeContextWorkflow: - def __init__(self) -> None: - self.app = entrypoint("read_user_id") - - @workflow.run - async def run(self, user_id: str) -> Any: - return await self.app.ainvoke("", context=Context(user_id=user_id)) - - -async def test_functional_runtime_context(client: Client) -> None: - task_queue = f"functional-runtime-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[FunctionalRuntimeContextWorkflow], - plugins=[ - LangGraphPlugin( - entrypoints={"read_user_id": read_user_id_entrypoint}, - tasks=[read_user_id], - activity_options={"read_user_id": {"execute_in": "activity"}}, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=10) - }, - ) - ], - ): - result = await client.execute_workflow( - FunctionalRuntimeContextWorkflow.run, - "user-123", - id=f"test-functional-runtime-{uuid4()}", - task_queue=task_queue, - ) - - assert result == {"user_id": "user-123"} diff --git a/tests/contrib/langgraph/test_runtime.py b/tests/contrib/langgraph/test_runtime.py deleted file mode 100644 index 90104b1b2..000000000 --- a/tests/contrib/langgraph/test_runtime.py +++ /dev/null @@ -1,66 +0,0 @@ -from datetime import timedelta -from typing import Any -from uuid import uuid4 - -from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] -from langgraph.runtime import Runtime -from typing_extensions import TypedDict - -from temporalio import workflow -from temporalio.client import Client -from temporalio.contrib.langgraph import LangGraphPlugin, graph -from temporalio.worker import Worker - - -class Context(TypedDict): - user_id: str - - -class State(TypedDict): - user_id: str - - -async def read_user_id(state: State, runtime: Runtime[Context]) -> dict[str, str]: - return {"user_id": runtime.context["user_id"]} - - -@workflow.defn -class RuntimeContextWorkflow: - def __init__(self) -> None: - self.app = graph("my-graph").compile() - - @workflow.run - async def run(self, user_id: str) -> Any: - return await self.app.ainvoke( - {"user_id": ""}, context=Context(user_id=user_id) - ) - - -async def test_runtime_context(client: Client): - g = StateGraph(State, context_schema=Context) - g.add_node("read_user_id", read_user_id, metadata={"execute_in": "activity"}) - g.add_edge(START, "read_user_id") - - task_queue = f"runtime-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[RuntimeContextWorkflow], - plugins=[ - LangGraphPlugin( - graphs={"my-graph": g}, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=10) - }, - ) - ], - ): - result = await client.execute_workflow( - RuntimeContextWorkflow.run, - "user-123", - id=f"test-runtime-{uuid4()}", - task_queue=task_queue, - ) - - assert result == {"user_id": "user-123"} From 59be7fa5cfbafec1cfc7f91209210d25cfff96f2 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 13:08:00 -0700 Subject: [PATCH 04/15] Tidy langgraph streaming tests --- tests/contrib/langgraph/test_streaming.py | 103 ++++++---------------- 1 file changed, 27 insertions(+), 76 deletions(-) diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index f34b28f2e..43e148081 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -3,7 +3,9 @@ from typing import Any from uuid import uuid4 -from langgraph.config import get_stream_writer # pyright: ignore[reportMissingTypeStubs] +from langgraph.config import ( + get_stream_writer, # pyright: ignore[reportMissingTypeStubs] +) from langgraph.graph import START, StateGraph # pyright: ignore[reportMissingTypeStubs] from typing_extensions import TypedDict @@ -18,73 +20,13 @@ class State(TypedDict): value: str -async def node_a(state: State) -> dict[str, str]: - return {"value": state["value"] + "a"} - - -async def node_b(state: State) -> dict[str, str]: - return {"value": state["value"] + "b"} - - -@workflow.defn -class StreamingWorkflow: - def __init__(self) -> None: - self.app = graph("streaming").compile() - - @workflow.run - async def run(self, input: str) -> Any: - chunks = [] - async for chunk in self.app.astream({"value": input}): - chunks.append(chunk) - return chunks - - -async def test_streaming(client: Client): - g = StateGraph(State) - g.add_node("node_a", node_a, metadata={"execute_in": "activity"}) - g.add_node("node_b", node_b, metadata={"execute_in": "activity"}) - g.add_edge(START, "node_a") - g.add_edge("node_a", "node_b") - - task_queue = f"streaming-{uuid4()}" - - async with Worker( - client, - task_queue=task_queue, - workflows=[StreamingWorkflow], - plugins=[ - LangGraphPlugin( - graphs={"streaming": g}, - default_activity_options={ - "start_to_close_timeout": timedelta(seconds=10) - }, - ) - ], - ): - chunks = await client.execute_workflow( - StreamingWorkflow.run, - "", - id=f"test-streaming-{uuid4()}", - task_queue=task_queue, - ) - - assert chunks == [{"node_a": {"value": "a"}}, {"node_b": {"value": "ab"}}] - - -# --------------------------------------------------------------------------- -# Streaming via WorkflowStream: stream_writer inside an activity-wrapped node -# publishes back to the owning workflow, an external client subscribes. -# --------------------------------------------------------------------------- - -TOKENS = ["alpha", "beta", "gamma"] - - async def token_node(state: State) -> dict[str, str]: + tokens = ["a", "b", "c"] writer = get_stream_writer() - for token in TOKENS: + for token in tokens: writer({"token": token}) writer({"done": True}) - return {"value": "".join(TOKENS)} + return {"value": state["value"] + "".join(tokens)} @workflow.defn @@ -128,21 +70,21 @@ async def test_streaming_via_workflow_streams(client: Client): ws_client = WorkflowStreamClient.create(client, handle.id) chunks: list[dict[str, Any]] = [] - async with asyncio.timeout(15.0): - async for item in ws_client.topic(STREAM_TOPIC, type=dict).subscribe( - from_offset=0, poll_cooldown=timedelta(0), - ): - chunks.append(item.data) - if chunks[-1].get("done"): - break + async for item in ws_client.topic(STREAM_TOPIC, type=dict).subscribe( + from_offset=0, + poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break result = await handle.result() - assert result == "alphabetagamma" + assert result == "abc" assert chunks == [ - {"token": "alpha"}, - {"token": "beta"}, - {"token": "gamma"}, + {"token": "a"}, + {"token": "b"}, + {"token": "c"}, {"done": True}, ] @@ -169,6 +111,14 @@ async def run(self, input: str) -> str: return "done" +async def node_a(state: State) -> dict[str, str]: + return {"value": state["value"] + "a"} + + +async def node_b(state: State) -> dict[str, str]: + return {"value": state["value"] + "b"} + + async def test_workflow_publishes_astream_chunks(client: Client): g = StateGraph(State) g.add_node("node_a", node_a, metadata={"execute_in": "activity"}) @@ -202,7 +152,8 @@ async def test_workflow_publishes_astream_chunks(client: Client): chunks: list[dict[str, Any]] = [] async with asyncio.timeout(15.0): async for item in ws_client.topic("astream", type=dict).subscribe( - from_offset=0, poll_cooldown=timedelta(0), + from_offset=0, + poll_cooldown=timedelta(0), ): chunks.append(item.data) if chunks[-1].get("done"): From a4c5b9b045bf9bed201d8014e31e4362cd51b908 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 13:19:07 -0700 Subject: [PATCH 05/15] don't store workflowstream --- tests/contrib/langgraph/test_streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 43e148081..78a22ebd0 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -32,7 +32,7 @@ async def token_node(state: State) -> dict[str, str]: @workflow.defn class StreamingWorkflowStreamsWorkflow: def __init__(self) -> None: - self.stream = WorkflowStream() + _ = WorkflowStream() self.app = graph("streaming-ws").compile() @workflow.run From 2eb07f325abed9a793ba58681c4b098681ca96f5 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 13:27:15 -0700 Subject: [PATCH 06/15] remove timeout --- tests/contrib/langgraph/test_streaming.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 78a22ebd0..4a01aeb5f 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -150,14 +150,13 @@ async def test_workflow_publishes_astream_chunks(client: Client): ws_client = WorkflowStreamClient.create(client, handle.id) chunks: list[dict[str, Any]] = [] - async with asyncio.timeout(15.0): - async for item in ws_client.topic("astream", type=dict).subscribe( - from_offset=0, - poll_cooldown=timedelta(0), - ): - chunks.append(item.data) - if chunks[-1].get("done"): - break + async for item in ws_client.topic("astream", type=dict).subscribe( + from_offset=0, + poll_cooldown=timedelta(0), + ): + chunks.append(item.data) + if chunks[-1].get("done"): + break await handle.result() From 9ce797819a80ebffb5d0d45a09973e81e56d9b91 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 13:32:31 -0700 Subject: [PATCH 07/15] fix lint --- tests/contrib/langgraph/test_streaming.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 4a01aeb5f..c1735f47f 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -1,4 +1,3 @@ -import asyncio from datetime import timedelta from typing import Any from uuid import uuid4 From 642b4b9bb7864fceef77461652c4109990841527 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 14:31:14 -0700 Subject: [PATCH 08/15] Make langgraph streaming opt-in via streaming_topic --- temporalio/contrib/langgraph/__init__.py | 2 -- temporalio/contrib/langgraph/_activity.py | 25 ++++++++++++----------- temporalio/contrib/langgraph/_plugin.py | 12 ++++++++++- tests/contrib/langgraph/test_streaming.py | 5 +++-- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/temporalio/contrib/langgraph/__init__.py b/temporalio/contrib/langgraph/__init__.py index f7f571bcc..e9aaf5605 100644 --- a/temporalio/contrib/langgraph/__init__.py +++ b/temporalio/contrib/langgraph/__init__.py @@ -10,7 +10,6 @@ API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). """ -from temporalio.contrib.langgraph._activity import STREAM_TOPIC from temporalio.contrib.langgraph._plugin import ( LangGraphPlugin, cache, @@ -20,7 +19,6 @@ __all__ = [ "LangGraphPlugin", - "STREAM_TOPIC", "cache", "entrypoint", "graph", diff --git a/temporalio/contrib/langgraph/_activity.py b/temporalio/contrib/langgraph/_activity.py index f7ee97c0e..92f2dca0c 100644 --- a/temporalio/contrib/langgraph/_activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -2,6 +2,7 @@ from collections.abc import Awaitable from dataclasses import dataclass +from datetime import timedelta from inspect import iscoroutinefunction, signature from typing import Any, Callable @@ -21,9 +22,6 @@ ) from temporalio.contrib.workflow_streams import WorkflowStreamClient -STREAM_TOPIC = "langgraph_stream" -"""Workflow stream topic that LangGraph stream_writer publishes to.""" - # Per-run dedupe so we only warn once when a user passes a Store via # graph.compile(store=...) / @entrypoint(store=...). Cleared by # LangGraphInterceptor.execute_workflow on workflow exit. @@ -55,6 +53,9 @@ class ActivityOutput: def wrap_activity( func: Callable, + *, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks @@ -63,15 +64,7 @@ def wrap_activity( accepts_runtime = "runtime" in signature(func).parameters async def wrapper(input: ActivityInput) -> ActivityOutput: - # Back get_stream_writer() with a WorkflowStreamClient targeting the - # owning workflow. Chunks emitted inside the node are signaled back - # to the workflow's WorkflowStream. - client = WorkflowStreamClient.from_within_activity() - - def stream_writer(chunk: Any) -> None: - client.topic(STREAM_TOPIC).publish(chunk, force_flush=True) - - async with client: + async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput: runtime = set_langgraph_config( input.langgraph_config, stream_writer=stream_writer ) @@ -89,6 +82,14 @@ def stream_writer(chunk: Any) -> None: except GraphInterrupt as e: return ActivityOutput(langgraph_interrupts=e.args[0]) + if streaming_topic is None: + return await run(stream_writer=None) + async with WorkflowStreamClient.from_within_activity( + batch_interval=streaming_batch_interval, + ) as client: + topic = client.topic(streaming_topic) + return await run(stream_writer=topic.publish) + return wrapper diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index a624f62a7..828cea8d5 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -8,6 +8,7 @@ import sys import warnings from dataclasses import replace +from datetime import timedelta from typing import Any, Callable from langgraph._internal._runnable import RunnableCallable @@ -58,6 +59,8 @@ def __init__( # TODO: Remove activity_options when we have support for @task(metadata=...) activity_options: dict[str, dict[str, Any]] | None = None, default_activity_options: dict[str, Any] | None = None, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), ): """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" if sys.version_info < (3, 11): @@ -79,6 +82,8 @@ def __init__( ) self.activities: list = [] + self._streaming_topic = streaming_topic + self._streaming_batch_interval = streaming_batch_interval # Graph API: Wrap graph nodes as Temporal Activities. if graphs: @@ -197,7 +202,12 @@ def execute( execute_in = opts.pop("execute_in") if execute_in == "activity": - a = activity.defn(name=activity_name)(wrap_activity(func)) + wrapped = wrap_activity( + func, + streaming_topic=self._streaming_topic, + streaming_batch_interval=self._streaming_batch_interval, + ) + a = activity.defn(name=activity_name)(wrapped) self.activities.append(a) return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index c1735f47f..6d3987cd0 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -10,7 +10,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.langgraph import STREAM_TOPIC, LangGraphPlugin, graph +from temporalio.contrib.langgraph import LangGraphPlugin, graph from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Worker @@ -57,6 +57,7 @@ async def test_streaming_via_workflow_streams(client: Client): default_activity_options={ "start_to_close_timeout": timedelta(seconds=10) }, + streaming_topic="tokens", ) ], ): @@ -69,7 +70,7 @@ async def test_streaming_via_workflow_streams(client: Client): ws_client = WorkflowStreamClient.create(client, handle.id) chunks: list[dict[str, Any]] = [] - async for item in ws_client.topic(STREAM_TOPIC, type=dict).subscribe( + async for item in ws_client.topic("tokens", type=dict).subscribe( from_offset=0, poll_cooldown=timedelta(0), ): From 04ea6653a4d84c3a1f6e207d25b95dc40ecc57e2 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 1 May 2026 14:36:29 -0700 Subject: [PATCH 09/15] add streaming support disclaimer --- temporalio/contrib/langgraph/_plugin.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index 828cea8d5..54e9be739 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -62,7 +62,12 @@ def __init__( streaming_topic: str | None = None, streaming_batch_interval: timedelta = timedelta(milliseconds=100), ): - """Initialize the LangGraph plugin with graphs, entrypoints, and tasks.""" + """Initialize the LangGraph plugin with graphs, entrypoints, and tasks. + + .. warning:: + Streaming support is experimental and may change in + future versions. + """ if sys.version_info < (3, 11): warnings.warn( # type: ignore[reportUnreachable] "LangGraphPlugin requires Python >= 3.11 for full async support. " From 5e7881aa55dfa1e9e4d5650b9f0fa6337f986a32 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 4 May 2026 13:05:11 -0700 Subject: [PATCH 10/15] mention streaming in readme --- temporalio/contrib/langgraph/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index 7c41b5da7..fdb230131 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -143,6 +143,10 @@ await g.ainvoke({...}, context=Context(user_id="alice")) Your `context` object must be serializable by the configured Temporal payload converter, since it crosses the Activity boundary. +## Streaming + +When `streaming_topic` is set on `LangGraphPlugin`, calls to `stream_writer` leverage Temporal [Workflow Streams](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Async nodes are recommended for this feature. + ## Tracing We recommend the [Temporal LangSmith Plugin](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/langsmith) to trace your LangGraph Workflows and Activities. From 32c2c2eb815cffc4a634c306bda4191436c4fcc1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 5 May 2026 14:03:25 -0700 Subject: [PATCH 11/15] Validate WorkflowStream registration when streaming_topic is set The LangGraph interceptor now checks at workflow start that a WorkflowStream has been registered (via the publish signal handler) when the plugin was configured with streaming_topic. Misconfigured workflows fail fast with a clear error pointing at @workflow.init, instead of silently producing no-op streams. --- temporalio/contrib/langgraph/_interceptor.py | 16 ++++++++++++++++ temporalio/contrib/langgraph/_plugin.py | 6 +++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/langgraph/_interceptor.py b/temporalio/contrib/langgraph/_interceptor.py index fd583c052..f68d9d45d 100644 --- a/temporalio/contrib/langgraph/_interceptor.py +++ b/temporalio/contrib/langgraph/_interceptor.py @@ -11,6 +11,7 @@ from temporalio import workflow from temporalio.contrib.langgraph._activity import clear_store_warning +from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL from temporalio.worker import ( ExecuteWorkflowInput, Interceptor, @@ -30,10 +31,12 @@ def __init__( self, graphs: dict[str, StateGraph[Any, Any, Any, Any]], entrypoints: dict[str, Pregel[Any, Any, Any, Any]], + streaming_topic: str | None = None, ) -> None: """Initialize with the graphs and entrypoints to scope to each workflow run.""" self._graphs = graphs self._entrypoints = entrypoints + self._streaming_topic = streaming_topic def workflow_interceptor_class( self, input: WorkflowInterceptorClassInput @@ -41,6 +44,7 @@ def workflow_interceptor_class( """Return the inbound interceptor class used to scope graphs per run.""" graphs = self._graphs entrypoints = self._entrypoints + streaming_topic = self._streaming_topic class Inbound(WorkflowInboundInterceptor): def init(self, outbound: WorkflowOutboundInterceptor) -> None: @@ -50,6 +54,18 @@ def init(self, outbound: WorkflowOutboundInterceptor) -> None: super().init(outbound) async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + if ( + streaming_topic is not None + and workflow.get_signal_handler(_PUBLISH_SIGNAL) is None + ): + raise RuntimeError( + f"LangGraphPlugin was configured with " + f"streaming_topic={streaming_topic!r}, but workflow " + f"{workflow.info().workflow_type!r} did not register a " + f"WorkflowStream. Construct WorkflowStream() in the " + f"workflow's @workflow.init (i.e. __init__) method so " + f"streaming activities can publish to it." + ) try: return await self.next.execute_workflow(input) finally: diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index 54e9be739..6b010d6f0 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -193,7 +193,11 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "langchain.LangGraphPlugin", activities=self.activities, workflow_runner=workflow_runner, - interceptors=[LangGraphInterceptor(graphs or {}, entrypoints or {})], + interceptors=[ + LangGraphInterceptor( + graphs or {}, entrypoints or {}, streaming_topic=streaming_topic + ) + ], ) def execute( From a62677a05f212d544d0249685c5f3865ef67265b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 5 May 2026 14:57:45 -0700 Subject: [PATCH 12/15] Stream from workflow-side LangGraph nodes via in-workflow WorkflowStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wrap execute_in='workflow' nodes with wrap_workflow(), which mirrors wrap_activity() and (when streaming_topic is set) overrides the LangGraph Runtime's stream_writer to publish synchronously to the in-workflow WorkflowStream — no signal round-trip. Parametrized the streaming test over execute_in to cover both paths. --- temporalio/contrib/langgraph/_plugin.py | 3 +- temporalio/contrib/langgraph/_workflow.py | 62 +++++++++++++++++++++++ tests/contrib/langgraph/test_streaming.py | 6 ++- 3 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 temporalio/contrib/langgraph/_workflow.py diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index 6b010d6f0..a2c000685 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -27,6 +27,7 @@ set_task_cache, task_id, ) +from temporalio.contrib.langgraph._workflow import wrap_workflow from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -220,7 +221,7 @@ def execute( self.activities.append(a) return wrap_execute_activity(a, task_id=task_id(func), **opts) elif execute_in == "workflow": - return func + return wrap_workflow(func, streaming_topic=self._streaming_topic) else: raise ValueError(f"Invalid execute_in value: {execute_in}") diff --git a/temporalio/contrib/langgraph/_workflow.py b/temporalio/contrib/langgraph/_workflow.py new file mode 100644 index 000000000..67bfd4f68 --- /dev/null +++ b/temporalio/contrib/langgraph/_workflow.py @@ -0,0 +1,62 @@ +"""Workflow-side wrappers for executing LangGraph nodes inline in a workflow.""" + +# pyright: reportMissingTypeStubs=false + +from __future__ import annotations + +import dataclasses +from collections.abc import Awaitable +from inspect import iscoroutinefunction +from typing import Any, Callable + +from langchain_core.runnables.config import var_child_runnable_config +from langgraph._internal._constants import CONFIG_KEY_RUNTIME + +from temporalio import workflow +from temporalio.contrib.workflow_streams._stream import _PUBLISH_SIGNAL + + +def wrap_workflow( + func: Callable[..., Any], + *, + streaming_topic: str | None = None, +) -> Callable[..., Awaitable[Any]]: + """Wrap a function as a workflow-side LangGraph node. + + Mirrors :func:`wrap_activity`: the outer wrapper resolves a stream + writer and passes it to an inner ``run`` that invokes the user + function with the writer installed. Workflow-side nodes publish + synchronously to the in-workflow ``WorkflowStream`` (no signal + round-trip); activity-side nodes go through ``WorkflowStreamClient``. + """ + + async def wrapper(*args: Any, **kwargs: Any) -> Any: + async def run(stream_writer: Callable[[Any], None] | None) -> Any: + token = None + if stream_writer is not None: + config = var_child_runnable_config.get() or {} + configurable = dict(config.get("configurable") or {}) + runtime = configurable.get(CONFIG_KEY_RUNTIME) + if runtime is not None: + configurable[CONFIG_KEY_RUNTIME] = dataclasses.replace( + runtime, stream_writer=stream_writer + ) + token = var_child_runnable_config.set( + {**config, "configurable": configurable} + ) + try: + if iscoroutinefunction(func): + return await func(*args, **kwargs) + return func(*args, **kwargs) + finally: + if token is not None: + var_child_runnable_config.reset(token) + + if streaming_topic is None: + return await run(stream_writer=None) + publish_handler = workflow.get_signal_handler(_PUBLISH_SIGNAL) + stream = getattr(publish_handler, "__self__") + topic = stream.topic(streaming_topic) + return await run(stream_writer=topic.publish) + + return wrapper diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index 6d3987cd0..bd8f8d159 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -2,6 +2,7 @@ from typing import Any from uuid import uuid4 +import pytest from langgraph.config import ( get_stream_writer, # pyright: ignore[reportMissingTypeStubs] ) @@ -40,9 +41,10 @@ async def run(self, input: str) -> str: return result["value"] -async def test_streaming_via_workflow_streams(client: Client): +@pytest.mark.parametrize("execute_in", ["activity", "workflow"]) +async def test_streaming_via_workflow_streams(client: Client, execute_in: str): g = StateGraph(State) - g.add_node("token_node", token_node, metadata={"execute_in": "activity"}) + g.add_node("token_node", token_node, metadata={"execute_in": execute_in}) g.add_edge(START, "token_node") task_queue = f"streaming-ws-{uuid4()}" From 5802d88e65f050fcb67535fc2ded2ff568425bfa Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 5 May 2026 15:20:05 -0700 Subject: [PATCH 13/15] Document streaming feature in README and plugin docstring Expand the README streaming section with a self-contained snippet (plugin, WorkflowStream in __init__, external subscriber loop), an explicit callout that streaming_topic only covers stream_mode='custom' with an astream() bridge example for other modes, and at-least-once retry semantics. Add an Args section to LangGraphPlugin's docstring covering all constructor parameters. --- temporalio/contrib/langgraph/README.md | 101 +++++++++++++++++++++++- temporalio/contrib/langgraph/_plugin.py | 43 ++++++++++ 2 files changed, 143 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index fdb230131..c281e93c8 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -145,7 +145,106 @@ Your `context` object must be serializable by the configured Temporal payload co ## Streaming -When `streaming_topic` is set on `LangGraphPlugin`, calls to `stream_writer` leverage Temporal [Workflow Streams](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Async nodes are recommended for this feature. +When `streaming_topic` is set on `LangGraphPlugin`, calls to `langgraph.config.get_stream_writer()` inside a node publish to the named topic on the workflow's [`WorkflowStream`](https://github.com/temporalio/sdk-python/tree/main/temporalio/contrib/workflow_streams). Activity-side nodes publish via `WorkflowStreamClient` (a signal carrying batched items, controlled by `streaming_batch_interval`); workflow-side nodes publish synchronously to the in-workflow stream (no signal). External subscribers consume the stream with `WorkflowStreamClient.create(...).topic(...).subscribe(...)`. + +The workflow **must** construct `WorkflowStream()` in its `@workflow.init` (i.e. `__init__`) + +```python +from datetime import timedelta +from typing import Any + +from langgraph.config import get_stream_writer +from langgraph.graph import START, StateGraph +from typing_extensions import TypedDict + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.langgraph import LangGraphPlugin, graph +from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient +from temporalio.worker import Worker + + +class State(TypedDict): + value: str + + +async def token_node(state: State) -> dict[str, str]: + writer = get_stream_writer() + for token in ["hello", " ", "world"]: + writer({"token": token}) + writer({"done": True}) + return {"value": "hello world"} + + +@workflow.defn +class StreamingWorkflow: + def __init__(self) -> None: + # Required when streaming_topic is set on the plugin. + _ = WorkflowStream() + self.app = graph("streaming").compile() + + @workflow.run + async def run(self) -> str: + result = await self.app.ainvoke({"value": ""}) + return result["value"] + + +async def main(client: Client) -> None: + g = StateGraph(State) + g.add_node("token_node", token_node, metadata={"execute_in": "activity"}) + g.add_edge(START, "token_node") + + async with Worker( + client, + task_queue="streaming-tq", + workflows=[StreamingWorkflow], + plugins=[ + LangGraphPlugin( + graphs={"streaming": g}, + default_activity_options={ + "start_to_close_timeout": timedelta(seconds=10) + }, + streaming_topic="tokens", + ) + ], + ): + handle = await client.start_workflow( + StreamingWorkflow.run, id="streaming-wf", task_queue="streaming-tq" + ) + + ws_client = WorkflowStreamClient.create(client, handle.id) + async for item in ws_client.topic("tokens", type=dict).subscribe(from_offset=0): + print(item.data) + if item.data.get("done"): + break + + print(await handle.result()) +``` + +### What's covered, and what isn't + +`streaming_topic` wires up exactly **one** LangGraph stream mode: `stream_mode="custom"`, i.e. values written through `get_stream_writer()`. The other modes — `"messages"`, `"values"`, `"updates"`, `"debug"` — are **not** captured by `streaming_topic`. They aren't produced by node-side writers; LangGraph's orchestrator emits them as it walks the graph. The documented pattern is to **bridge `astream()` in the workflow** and republish each yielded chunk to a `WorkflowStream` topic yourself: + +```python +@workflow.defn +class AstreamBridge: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.app = graph("g").compile() + + @workflow.run + async def run(self) -> None: + topic = self.stream.topic("astream") + async for chunk in self.app.astream({...}, stream_mode="messages"): + topic.publish(chunk) + topic.publish({"done": True}) +``` + +The two mechanisms compose. A workflow can both set `streaming_topic="tokens"` (so nodes' `get_stream_writer()` calls publish to `"tokens"`) **and** iterate `astream()` to republish orchestrator-level chunks to a separate topic (e.g. `"messages"`). External subscribers pick the topic that matches what they want. + +### Retry semantics + +Streaming has **at-least-once** delivery per activity attempt. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes — earlier publishes from the failed attempt are not rolled back. Subscribers should be ready to see duplicates and recover idempotently (e.g. dedupe on a sequence id you include in each chunk, or treat the stream as advisory and rely on the workflow's final result for state). ## Tracing diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index a2c000685..33cafe8d2 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -48,6 +48,49 @@ class LangGraphPlugin(SimplePlugin): and tasks as Temporal Activities, giving your AI agent workflows durable execution, automatic retries, and timeouts. It supports both the LangGraph Graph API (``StateGraph``) and Functional API (``@entrypoint`` / ``@task``). + + Args: + graphs: Graph API graphs to make available to workflows, keyed by name. + Workflows retrieve them with :func:`graph` and call + ``.compile()`` to get a runnable. Each node's ``metadata`` must + include ``execute_in`` (``"activity"`` or ``"workflow"``) and + may include any kwarg accepted by + :func:`workflow.execute_activity` (e.g. ``start_to_close_timeout``, + ``retry_policy``). + entrypoints: Functional API entrypoints to make available to + workflows, keyed by name. Workflows retrieve them with + :func:`entrypoint`. + tasks: Functional API ``@task`` functions to wrap as Temporal + Activities. + activity_options: Per-task activity options for the Functional + API, keyed by task function name. Each entry must include + ``execute_in`` and may include any + :func:`workflow.execute_activity` kwarg. Used because LangGraph's + Functional API has no per-task ``metadata`` channel. + default_activity_options: Activity options applied to every + activity-bound node and task, overridable per-node (Graph API + ``metadata``) or per-task (``activity_options[name]``). + streaming_topic: When set, ``langgraph.config.get_stream_writer()`` + inside a node publishes to this topic on the workflow's + :class:`WorkflowStream`. The workflow must construct + ``WorkflowStream()`` in its ``@workflow.init`` (the plugin's + interceptor verifies this on workflow start). Nodes with + ``execute_in='activity'`` publish through + :class:`WorkflowStreamClient` (signal); nodes with + ``execute_in='workflow'`` publish synchronously to the + in-workflow stream (no signal). + streaming_batch_interval: How often the activity-side stream + client flushes buffered publishes into a single + ``__temporal_workflow_stream_publish`` signal. Has no effect + on workflow-side nodes (their publishes are synchronous + in-memory log appends). Lower values reduce streaming + latency at the cost of more signals (more workflow history + events); higher values amortize signal cost but make + chunks arrive in larger bursts. Default 100ms suits + interactive token streaming; raise to 250–1000ms for + non-interactive aggregation, lower toward 10–50ms only if + you've measured the latency need and accept the history + cost. """ def __init__( From 627ee3f983d08f87f623957043a99a00af517a18 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 5 May 2026 15:21:48 -0700 Subject: [PATCH 14/15] Drop compose-mechanisms paragraph from streaming README --- temporalio/contrib/langgraph/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/temporalio/contrib/langgraph/README.md b/temporalio/contrib/langgraph/README.md index c281e93c8..dafe598b7 100644 --- a/temporalio/contrib/langgraph/README.md +++ b/temporalio/contrib/langgraph/README.md @@ -240,8 +240,6 @@ class AstreamBridge: topic.publish({"done": True}) ``` -The two mechanisms compose. A workflow can both set `streaming_topic="tokens"` (so nodes' `get_stream_writer()` calls publish to `"tokens"`) **and** iterate `astream()` to republish orchestrator-level chunks to a separate topic (e.g. `"messages"`). External subscribers pick the topic that matches what they want. - ### Retry semantics Streaming has **at-least-once** delivery per activity attempt. When an activity-wrapped node retries (transient failure, worker crash, etc.), the user function re-runs from scratch and re-publishes its writes — earlier publishes from the failed attempt are not rolled back. Subscribers should be ready to see duplicates and recover idempotently (e.g. dedupe on a sequence id you include in each chunk, or treat the stream as advisory and rely on the workflow's final result for state). From bda13575b52d6da6fd61170fef24a89455621728 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 5 May 2026 15:40:58 -0700 Subject: [PATCH 15/15] Support sync nodes for streaming and execute_in='workflow' MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pick the raw user function from runnable.func instead of LangGraph's async runnable.afunc adapter, which wraps sync nodes in loop.run_in_executor — that's incompatible with the workflow event loop. wrap_activity now schedules sync funcs on a thread via asyncio.to_thread so the activity loop stays free for the streaming flusher, with stream_writer calls marshaled back to the loop thread to keep the workflow_streams client's asyncio.Event safe. Parametrize the streaming test over (execute_in, sync/async). --- temporalio/contrib/langgraph/_activity.py | 26 +++++++++++++++++------ temporalio/contrib/langgraph/_plugin.py | 2 +- tests/contrib/langgraph/test_streaming.py | 20 ++++++++++++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/temporalio/contrib/langgraph/_activity.py b/temporalio/contrib/langgraph/_activity.py index 92f2dca0c..5891eb81b 100644 --- a/temporalio/contrib/langgraph/_activity.py +++ b/temporalio/contrib/langgraph/_activity.py @@ -1,5 +1,6 @@ """Activity wrappers for executing LangGraph nodes and tasks.""" +import asyncio from collections.abc import Awaitable from dataclasses import dataclass from datetime import timedelta @@ -58,24 +59,35 @@ def wrap_activity( streaming_batch_interval: timedelta = timedelta(milliseconds=100), ) -> Callable[[ActivityInput], Awaitable[ActivityOutput]]: """Wrap a function as a Temporal activity that handles LangGraph config and interrupts.""" - # Graph nodes declare `runtime: Runtime[Ctx]` in their signature; tasks - # don't and instead reach for Runtime via get_runtime(). We re-inject the - # reconstructed Runtime only when the user function asks. - accepts_runtime = "runtime" in signature(func).parameters async def wrapper(input: ActivityInput) -> ActivityOutput: async def run(stream_writer: Callable[[Any], None] | None) -> ActivityOutput: + # Sync funcs run on a thread (so the loop keeps flushing the + # stream client mid-execution); marshal writer calls back to + # the loop thread because the client's flush event is an + # asyncio.Event and isn't safe to set off-thread. + effective_writer = stream_writer + if not iscoroutinefunction(func) and stream_writer is not None: + loop = asyncio.get_running_loop() + inner_writer = stream_writer + + def thread_safe_writer(value: Any) -> None: + loop.call_soon_threadsafe(inner_writer, value) + + effective_writer = thread_safe_writer + runtime = set_langgraph_config( - input.langgraph_config, stream_writer=stream_writer + input.langgraph_config, stream_writer=effective_writer ) kwargs = dict(input.kwargs) - if accepts_runtime: + if "runtime" in signature(func).parameters: kwargs["runtime"] = runtime + try: if iscoroutinefunction(func): result = await func(*input.args, **kwargs) else: - result = func(*input.args, **kwargs) + result = await asyncio.to_thread(func, *input.args, **kwargs) if isinstance(result, Command): return ActivityOutput(langgraph_command=result) return ActivityOutput(result=result) diff --git a/temporalio/contrib/langgraph/_plugin.py b/temporalio/contrib/langgraph/_plugin.py index 33cafe8d2..a1320d1a8 100644 --- a/temporalio/contrib/langgraph/_plugin.py +++ b/temporalio/contrib/langgraph/_plugin.py @@ -149,7 +149,7 @@ def __init__( runnable = node.runnable if not isinstance(runnable, RunnableCallable): raise ValueError(f"Node {node_name} must be a RunnableCallable") - user_func = runnable.afunc or runnable.func + user_func = runnable.func or runnable.afunc if user_func is None: raise ValueError(f"Node {node_name} must have a function") # Keep 'config' (for metadata/tags) and 'runtime' (for diff --git a/tests/contrib/langgraph/test_streaming.py b/tests/contrib/langgraph/test_streaming.py index bd8f8d159..ff00b8906 100644 --- a/tests/contrib/langgraph/test_streaming.py +++ b/tests/contrib/langgraph/test_streaming.py @@ -20,7 +20,16 @@ class State(TypedDict): value: str -async def token_node(state: State) -> dict[str, str]: +async def async_token_node(state: State) -> dict[str, str]: + tokens = ["a", "b", "c"] + writer = get_stream_writer() + for token in tokens: + writer({"token": token}) + writer({"done": True}) + return {"value": state["value"] + "".join(tokens)} + + +def sync_token_node(state: State) -> dict[str, str]: tokens = ["a", "b", "c"] writer = get_stream_writer() for token in tokens: @@ -42,9 +51,14 @@ async def run(self, input: str) -> str: @pytest.mark.parametrize("execute_in", ["activity", "workflow"]) -async def test_streaming_via_workflow_streams(client: Client, execute_in: str): +@pytest.mark.parametrize( + "node", [async_token_node, sync_token_node], ids=["async", "sync"] +) +async def test_streaming_via_workflow_streams( + client: Client, execute_in: str, node: Any +): g = StateGraph(State) - g.add_node("token_node", token_node, metadata={"execute_in": execute_in}) + g.add_node("token_node", node, metadata={"execute_in": execute_in}) g.add_edge(START, "token_node") task_queue = f"streaming-ws-{uuid4()}"