Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/fast_agent/llm/provider/openai/llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from fast_agent.llm.usage_tracking import TurnUsage
from fast_agent.mcp.helpers.content_helpers import get_text
from fast_agent.types import LlmStopReason, PromptMessageExtended
from fast_agent.utils.reasoning_chunk_join import normalize_reasoning_delta

_logger = get_logger(__name__)

Expand Down Expand Up @@ -235,17 +236,22 @@ def _handle_reasoning_delta(
if not reasoning_text:
return reasoning_active

last_char = reasoning_segments[-1][-1] if reasoning_segments and reasoning_segments[-1] else None
normalized_text = normalize_reasoning_delta(last_char, reasoning_text)
if not normalized_text:
return reasoning_active

if reasoning_mode == "tags":
if not reasoning_active:
reasoning_active = True
self._notify_stream_listeners(StreamChunk(text=reasoning_text, is_reasoning=True))
reasoning_segments.append(reasoning_text)
self._notify_stream_listeners(StreamChunk(text=normalized_text, is_reasoning=True))
reasoning_segments.append(normalized_text)
return reasoning_active

if reasoning_mode in {"stream", "reasoning_content", "gpt_oss"}:
# Emit reasoning as-is
self._notify_stream_listeners(StreamChunk(text=reasoning_text, is_reasoning=True))
reasoning_segments.append(reasoning_text)
self._notify_stream_listeners(StreamChunk(text=normalized_text, is_reasoning=True))
reasoning_segments.append(normalized_text)
return reasoning_active

return reasoning_active
Expand Down
43 changes: 34 additions & 9 deletions src/fast_agent/llm/provider/openai/openresponses_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from fast_agent.llm.provider.openai.streaming_utils import finalize_stream_response
from fast_agent.llm.provider.openai.tool_notifications import OpenAIToolNotificationMixin
from fast_agent.llm.stream_types import StreamChunk
from fast_agent.utils.reasoning_chunk_join import normalize_reasoning_delta

if TYPE_CHECKING:
from openai.types.responses import (
Expand Down Expand Up @@ -201,11 +202,19 @@ async def _process_stream(
part_type = getattr(part, "type", None)
part_text = getattr(part, "text", None)
if part_type in {"reasoning", "reasoning_text"} and part_text:
reasoning_segments.append(part_text)
last_char = (
reasoning_segments[-1][-1]
if reasoning_segments and reasoning_segments[-1]
else None
)
normalized_delta = normalize_reasoning_delta(last_char, part_text)
if not normalized_delta:
continue
reasoning_segments.append(normalized_delta)
self._notify_stream_listeners(
StreamChunk(text=part_text, is_reasoning=True)
StreamChunk(text=normalized_delta, is_reasoning=True)
)
reasoning_chars += len(part_text)
reasoning_chars += len(normalized_delta)
await self._emit_streaming_progress(
model=f"{model} (reasoning)",
new_total=reasoning_chars,
Expand All @@ -218,11 +227,19 @@ async def _process_stream(
"response.reasoning_summary.delta",
}:
if delta:
reasoning_segments.append(delta)
last_char = (
reasoning_segments[-1][-1]
if reasoning_segments and reasoning_segments[-1]
else None
)
normalized_delta = normalize_reasoning_delta(last_char, delta)
if not normalized_delta:
continue
reasoning_segments.append(normalized_delta)
self._notify_stream_listeners(
StreamChunk(text=delta, is_reasoning=True)
StreamChunk(text=normalized_delta, is_reasoning=True)
)
reasoning_chars += len(delta)
reasoning_chars += len(normalized_delta)
await self._emit_streaming_progress(
model=f"{model} (summary)",
new_total=reasoning_chars,
Expand All @@ -235,11 +252,19 @@ async def _process_stream(
"response.reasoning_text.delta",
}:
if delta:
reasoning_segments.append(delta)
last_char = (
reasoning_segments[-1][-1]
if reasoning_segments and reasoning_segments[-1]
else None
)
normalized_delta = normalize_reasoning_delta(last_char, delta)
if not normalized_delta:
continue
reasoning_segments.append(normalized_delta)
self._notify_stream_listeners(
StreamChunk(text=delta, is_reasoning=True)
StreamChunk(text=normalized_delta, is_reasoning=True)
)
reasoning_chars += len(delta)
reasoning_chars += len(normalized_delta)
await self._emit_streaming_progress(
model=f"{model} (reasoning)",
new_total=reasoning_chars,
Expand Down
15 changes: 12 additions & 3 deletions src/fast_agent/llm/provider/openai/responses_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from fast_agent.llm.provider.openai.streaming_utils import finalize_stream_response
from fast_agent.llm.provider.openai.tool_notifications import OpenAIToolNotificationMixin
from fast_agent.llm.stream_types import StreamChunk
from fast_agent.utils.reasoning_chunk_join import normalize_reasoning_delta

_logger = get_logger(__name__)

Expand Down Expand Up @@ -156,11 +157,19 @@ def _close_tool(index: int, tool_use_id: str | None) -> None:
}:
delta = getattr(event, "delta", None)
if delta:
reasoning_segments.append(delta)
last_char = (
reasoning_segments[-1][-1]
if reasoning_segments and reasoning_segments[-1]
else None
)
normalized_delta = normalize_reasoning_delta(last_char, delta)
if not normalized_delta:
continue
reasoning_segments.append(normalized_delta)
self._notify_stream_listeners(
StreamChunk(text=delta, is_reasoning=True)
StreamChunk(text=normalized_delta, is_reasoning=True)
)
reasoning_chars += len(delta)
reasoning_chars += len(normalized_delta)
await self._emit_streaming_progress(
model=f"{model} (summary)",
new_total=reasoning_chars,
Expand Down
31 changes: 31 additions & 0 deletions src/fast_agent/utils/reasoning_chunk_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

# Characters that commonly terminate a sentence; a previous chunk ending in one
# of these may need a joining space before a sentence-like follow-up chunk.
_SENTENCE_PUNCTUATION = ".!?;:"
# Characters that commonly open a markdown construct (quote, backtick, bold/italic
# marker, link bracket); a chunk starting with one may begin a new sentence.
_MARKDOWN_PREFIXES = "\"`*["


def _looks_like_sentence_chunk(incoming: str) -> bool:
if not incoming:
return False
if " " not in incoming:
return False
first = incoming[0]
return first.isupper() or first in _MARKDOWN_PREFIXES


def normalize_reasoning_delta(last_char: str | None, incoming: str) -> str:
    """Normalize one reasoning delta without rebuilding the full accumulated text.

    Keeps the Codex-style append-only flow intact while patching the specific
    broken case where providers split natural-language reasoning into sentence
    chunks with no separating space, e.g. "approach." + "Specifying session
    retrieval format". Such chunks get a single leading space; everything else
    passes through unchanged.

    Args:
        last_char: Final character already emitted, or ``None``/empty when
            nothing has been emitted yet.
        incoming: Raw delta text from the provider.

    Returns:
        The text to emit/append — possibly prefixed with one space, or ``""``
        when the incoming delta is empty.
    """
    if not incoming:
        return ""
    head = incoming[0]
    # A boundary that already carries whitespace (or has no prior text) is fine.
    if not last_char or last_char.isspace() or head.isspace():
        return incoming
    # Heuristic: the chunk starts a new sentence when it contains a space and
    # opens with a capital letter or a markdown marker.
    sentence_like = " " in incoming and (head.isupper() or head in '"`*[')
    # Insert a space only after terminal punctuation or a lowercase letter, so
    # identifier fragments ("session" + "_id") and contractions stay joined.
    if sentence_like and (last_char in ".!?;:" or last_char.islower()):
        return f" {incoming}"
    return incoming
40 changes: 40 additions & 0 deletions tests/unit/fast_agent/test_reasoning_chunk_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from fast_agent.utils.reasoning_chunk_join import normalize_reasoning_delta


def test_normalize_reasoning_delta_inserts_space_after_sentence_break() -> None:
    """Sentence-like chunks arriving after terminal punctuation gain a joining space."""
    chunks = [
        "approach.",
        "Specifying session retrieval format",
        "Selecting session retrieval method",
    ]

    accumulated = ""
    for chunk in chunks:
        previous = accumulated[-1] if accumulated else None
        accumulated += normalize_reasoning_delta(previous, chunk)

    assert accumulated == "approach. Specifying session retrieval format Selecting session retrieval method"


def test_normalize_reasoning_delta_preserves_contractions() -> None:
    """Mid-word splits such as "don" + "'t" must be rejoined without a space."""
    accumulated = ""
    for fragment in ("don", "'t do that"):
        tail = accumulated[-1] if accumulated else None
        accumulated += normalize_reasoning_delta(tail, fragment)

    assert accumulated == "don't do that"


def test_normalize_reasoning_delta_preserves_identifier_fragments() -> None:
    """Identifier splits such as "session" + "_id" must be rejoined without a space."""
    accumulated = ""
    for fragment in ("session", "_id is required"):
        tail = accumulated[-1] if accumulated else None
        accumulated += normalize_reasoning_delta(tail, fragment)

    assert accumulated == "session_id is required"
Loading