diff --git a/chatkit/server.py b/chatkit/server.py index 9d82b64..c27a6af 100644 --- a/chatkit/server.py +++ b/chatkit/server.py @@ -697,7 +697,7 @@ async def _process_events( with agents_sdk_user_agent_override(): async for event in stream(): if isinstance(event, ThreadItemAddedEvent): - pending_items[event.item.id] = event.item + pending_items[event.item.id] = event.item.model_copy(deep=True) match event: case ThreadItemDoneEvent(): diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py index d7ac167..deeccb5 100644 --- a/tests/test_chatkit_server.py +++ b/tests/test_chatkit_server.py @@ -73,11 +73,15 @@ ThreadUpdatedEvent, ThreadUpdateParams, ToolChoice, + ThoughtTask, UserMessageInput, UserMessageItem, UserMessageTextContent, WidgetItem, WidgetRootUpdated, + Workflow, + WorkflowItem, + WorkflowTaskAdded, ) from chatkit.widgets import Card, Text from tests._types import RequestContext @@ -345,6 +349,58 @@ def generate_item_id( ) +async def test_workflow_updates_not_applied_twice(): + async def responder( + thread: ThreadMetadata, input: UserMessageItem | None, context: Any + ) -> AsyncIterator[ThreadStreamEvent]: + workflow_item = WorkflowItem( + id="wf_1", + created_at=datetime.now(), + workflow=Workflow(type="reasoning", tasks=[]), + thread_id=thread.id, + ) + yield ThreadItemAddedEvent(item=workflow_item) + + task = ThoughtTask(content="First thought") + # Simulate responders that mutate the workflow item before emitting an update + workflow_item.workflow.tasks.append(task) + yield ThreadItemUpdatedEvent( + item_id=workflow_item.id, + update=WorkflowTaskAdded(task=task, task_index=0), + ) + + yield ThreadItemDoneEvent(item=workflow_item) + + with make_server(responder) as server: + events = await server.process_streaming( + ThreadsCreateReq( + params=ThreadCreateParams( + input=UserMessageInput( + content=[UserMessageTextContent(text="Hello")], + attachments=[], + inference_options=InferenceOptions(), + ) + ) + ) + ) + + thread = next(event.thread for event in events if event.type == "thread.created") + workflow_done = next( + event + for event in events + if event.type == "thread.item.done" and event.item.type == "workflow" + ) + + assert len(workflow_done.item.workflow.tasks) == 1 + assert workflow_done.item.workflow.tasks[0].content == "First thought" + + stored_workflow = await server.store.load_item( + thread.id, "wf_1", DEFAULT_CONTEXT + ) + assert len(stored_workflow.workflow.tasks) == 1 + assert stored_workflow.workflow.tasks[0].content == "First thought" + + async def test_flows_context_to_responder(): responder_context = None add_feedback_context = None