1919)
2020import uuid
2121
22- from ag_ui .core import AssistantMessage
23- from ag_ui .core import CustomEvent as AguiCustomEvent
24- from ag_ui .core import EventType as AguiEventType
25- from ag_ui .core import Message as AguiMessage
26- from ag_ui .core import MessagesSnapshotEvent
27- from ag_ui .core import RawEvent as AguiRawEvent
28- from ag_ui .core import (
29- RunErrorEvent ,
30- RunFinishedEvent ,
31- RunStartedEvent ,
32- StateDeltaEvent ,
33- StateSnapshotEvent ,
34- StepFinishedEvent ,
35- StepStartedEvent ,
36- SystemMessage ,
37- TextMessageContentEvent ,
38- TextMessageEndEvent ,
39- TextMessageStartEvent ,
40- )
41- from ag_ui .core import Tool as AguiTool
42- from ag_ui .core import ToolCall as AguiToolCall
43- from ag_ui .core import (
44- ToolCallArgsEvent ,
45- ToolCallEndEvent ,
46- ToolCallResultEvent ,
47- ToolCallStartEvent ,
48- )
49- from ag_ui .core import ToolMessage as AguiToolMessage
50- from ag_ui .core import UserMessage
51- from ag_ui .encoder import EventEncoder
22+ if TYPE_CHECKING :
23+ from ag_ui .core import (
24+ Message as AguiMessage ,
25+ )
26+ from ag_ui .encoder import EventEncoder
27+
5228from fastapi import APIRouter , Request
5329from fastapi .responses import StreamingResponse
5430import pydash
@@ -101,16 +77,20 @@ class StreamStateMachine:
10177 run_errored : bool = False
10278
10379 def end_all_tools (
104- self , encoder : EventEncoder , exclude : Optional [str ] = None
80+ self , encoder : " EventEncoder" , exclude : Optional [str ] = None
10581 ) -> Iterator [str ]:
82+ from ag_ui .core import ToolCallEndEvent
83+
10684 for tool_id , state in self .tool_call_states .items ():
10785 if exclude and tool_id == exclude :
10886 continue
10987 if state .started and not state .ended :
11088 yield encoder .encode (ToolCallEndEvent (tool_call_id = tool_id ))
11189 state .ended = True
11290
113- def ensure_text_started (self , encoder : EventEncoder ) -> Iterator [str ]:
91+ def ensure_text_started (self , encoder : "EventEncoder" ) -> Iterator [str ]:
92+ from ag_ui .core import TextMessageStartEvent
93+
11494 if not self .text .started or self .text .ended :
11595 if self .text .ended :
11696 self .text = TextState ()
@@ -123,7 +103,9 @@ def ensure_text_started(self, encoder: EventEncoder) -> Iterator[str]:
123103 self .text .started = True
124104 self .text .ended = False
125105
126- def end_text_if_open (self , encoder : EventEncoder ) -> Iterator [str ]:
106+ def end_text_if_open (self , encoder : "EventEncoder" ) -> Iterator [str ]:
107+ from ag_ui .core import TextMessageEndEvent
108+
127109 if self .text .started and not self .text .ended :
128110 yield encoder .encode (
129111 TextMessageEndEvent (message_id = self .text .message_id )
@@ -168,6 +150,8 @@ class AGUIProtocolHandler(BaseProtocolHandler):
168150 name = "ag-ui"
169151
170152 def __init__ (self , config : Optional [ServerConfig ] = None ):
153+ from ag_ui .encoder import EventEncoder
154+
171155 self ._config = config .agui if config else None
172156 self ._encoder = EventEncoder ()
173157
@@ -420,6 +404,18 @@ def _process_event_with_boundaries(
420404 """处理事件并注入边界事件"""
421405 import json
422406
407+ from ag_ui .core import CustomEvent as AguiCustomEvent
408+ from ag_ui .core import (
409+ RunErrorEvent ,
410+ StateDeltaEvent ,
411+ StateSnapshotEvent ,
412+ TextMessageContentEvent ,
413+ ToolCallArgsEvent ,
414+ ToolCallEndEvent ,
415+ ToolCallResultEvent ,
416+ ToolCallStartEvent ,
417+ )
418+
423419 # RAW 事件直接透传
424420 if event .event == EventType .RAW :
425421 raw_data = event .data .get ("raw" , "" )
@@ -703,7 +699,7 @@ def _process_event_with_boundaries(
703699
704700 def _convert_messages_for_snapshot (
705701 self , messages : List [Dict [str , Any ]]
706- ) -> List [AguiMessage ]:
702+ ) -> List [" AguiMessage" ]:
707703 """将消息列表转换为 ag-ui-protocol 格式
708704
709705 Args:
@@ -712,6 +708,10 @@ def _convert_messages_for_snapshot(
712708 Returns:
713709 ag-ui-protocol 消息列表
714710 """
711+ from ag_ui .core import AssistantMessage , SystemMessage
712+ from ag_ui .core import ToolMessage as AguiToolMessage
713+ from ag_ui .core import UserMessage
714+
715715 result = []
716716 for msg in messages :
717717 if not isinstance (msg , dict ):
@@ -779,6 +779,8 @@ async def _error_stream(self, message: str) -> AsyncIterator[str]:
779779 Yields:
780780 SSE 格式的错误事件
781781 """
782+ from ag_ui .core import RunErrorEvent , RunStartedEvent
783+
782784 thread_id = str (uuid .uuid4 ())
783785 run_id = str (uuid .uuid4 ())
784786
0 commit comments