diff --git a/.gitignore b/.gitignore
index f346d70d..b7a00880 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,5 +11,8 @@
 logs/
 /.luarc.json
 
+# worktrees
+.worktrees/
+
 # setuptools_scm
 chatlas/_version.py
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 72c8f453..897a5246 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [UNRELEASED]
 
+### New features
+
+* The `.stream()` and `.stream_async()` methods now yield `ContentThinking` objects (instead of plain strings) for thinking/reasoning content when `content="all"`. This allows downstream packages like shinychat to provide dedicated UI for thinking content. (#276)
+
 ### Bug fixes
 
 * Fixed tool calling with Google thinking models (e.g., `gemini-3-flash-preview`) failing with a 400 `INVALID_ARGUMENT` error about a missing `thought_signature`. The signature is now preserved and forwarded in subsequent turns. (#274)
diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index 709f5bc0..cccc4941 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -34,6 +34,7 @@
     Content,
     ContentJson,
     ContentText,
+    ContentThinking,
     ContentToolRequest,
     ContentToolResult,
     ToolInfo,
@@ -1155,7 +1156,7 @@ def stream(
         echo: EchoOptions = "none",
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
+    ) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]: ...
 
     def stream(
         self,
@@ -1164,7 +1165,7 @@ def stream(
         echo: EchoOptions = "none",
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
+    ) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]:
         """
         Generate a response from the chat in a streaming fashion.
 
@@ -1228,7 +1229,7 @@ class Person(BaseModel):
         )
 
         def wrapper() -> Generator[
-            str | ContentToolRequest | ContentToolResult, None, None
+            str | ContentThinking | ContentToolRequest | ContentToolResult, None, None
         ]:
             with display:
                 for chunk in generator:
@@ -1254,7 +1255,7 @@ async def stream_async(
         echo: EchoOptions = "none",
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
+    ) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]: ...
 
     async def stream_async(
         self,
@@ -1263,7 +1264,7 @@ async def stream_async(
         echo: EchoOptions = "none",
         data_model: Optional[type[BaseModel]] = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
+    ) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]:
         """
         Generate a response from the chat in a streaming fashion asynchronously.
@@ -1309,9 +1310,12 @@ class Person(BaseModel):
 
         chat = ChatOpenAI()
 
-        chunks = [chunk async for chunk in await chat.stream_async(
-            "John is 25 years old", data_model=Person
-        )]
+        chunks = [
+            chunk
+            async for chunk in await chat.stream_async(
+                "John is 25 years old", data_model=Person
+            )
+        ]
         person = Person.model_validate_json("".join(chunks))
         ```
         """
@@ -1320,7 +1324,7 @@ class Person(BaseModel):
         display = self._markdown_display(echo=echo)
 
         async def wrapper() -> AsyncGenerator[
-            str | ContentToolRequest | ContentToolResult, None
+            str | ContentThinking | ContentToolRequest | ContentToolResult, None
         ]:
             with display:
                 async for chunk in self._chat_impl_async(
@@ -2481,7 +2485,7 @@ def _chat_impl(
         stream: bool,
         kwargs: Optional[SubmitInputArgsT] = None,
         data_model: Optional[type[BaseModel]] = None,
-    ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
+    ) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]: ...
 
     def _chat_impl(
         self,
@@ -2491,7 +2495,7 @@ def _chat_impl(
         stream: bool,
         kwargs: Optional[SubmitInputArgsT] = None,
         data_model: Optional[type[BaseModel]] = None,
-    ) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
+    ) -> Generator[str | Content, None, None]:
         user_turn_result: UserTurn | None = user_turn
         while user_turn_result is not None:
             for chunk in self._submit_turns(
@@ -2500,6 +2504,7 @@ def _chat_impl(
                 stream=stream,
                 data_model=data_model,
                 kwargs=kwargs,
+                content_mode=content,
             ):
                 yield chunk
 
@@ -2548,7 +2553,7 @@ def _chat_impl_async(
         stream: bool,
         kwargs: Optional[SubmitInputArgsT] = None,
         data_model: Optional[type[BaseModel]] = None,
-    ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
+    ) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]: ...
 
     async def _chat_impl_async(
         self,
@@ -2558,7 +2563,7 @@ async def _chat_impl_async(
         stream: bool,
         kwargs: Optional[SubmitInputArgsT] = None,
         data_model: Optional[type[BaseModel]] = None,
-    ) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
+    ) -> AsyncGenerator[str | Content, None]:
         user_turn_result: UserTurn | None = user_turn
         while user_turn_result is not None:
             async for chunk in self._submit_turns_async(
@@ -2567,6 +2572,7 @@ async def _chat_impl_async(
                 stream=stream,
                 data_model=data_model,
                 kwargs=kwargs,
+                content_mode=content,
             ):
                 yield chunk
 
@@ -2597,6 +2603,29 @@ async def _chat_impl_async(
         if all_results:
             user_turn_result = UserTurn(all_results)
 
+    @overload
+    def _submit_turns(
+        self,
+        user_turn: UserTurn,
+        echo: EchoOptions,
+        stream: bool,
+        data_model: type[BaseModel] | None = None,
+        kwargs: Optional[SubmitInputArgsT] = None,
+        content_mode: Literal["text"] = "text",
+    ) -> Generator[str, None, None]: ...
+
+    @overload
+    def _submit_turns(
+        self,
+        user_turn: UserTurn,
+        echo: EchoOptions,
+        stream: bool,
+        data_model: type[BaseModel] | None = None,
+        kwargs: Optional[SubmitInputArgsT] = None,
+        *,
+        content_mode: Literal["all"],
+    ) -> Generator[str | Content, None, None]: ...
+
     def _submit_turns(
         self,
         user_turn: UserTurn,
@@ -2604,7 +2633,8 @@ def _submit_turns(
         stream: bool,
         data_model: type[BaseModel] | None = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> Generator[str, None, None]:
+        content_mode: Literal["text", "all"] = "text",
+    ) -> Generator[str | Content, None, None]:
         if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
             raise ValueError("Cannot use async tools in a synchronous chat")
 
@@ -2630,10 +2660,17 @@ def emit(text: str | Content):
             result = None
             for chunk in response:
-                text = self.provider.stream_text(chunk)
-                if text:
-                    emit(text)
-                    yield text
+                content = self.provider.stream_content(chunk)
+                if content is not None:
+                    text = content_text(content)
+                    if text:
+                        emit(text)
+                        if content_mode == "all" and isinstance(
+                            content, ContentThinking
+                        ):
+                            yield content
+                        else:
+                            yield text
                 result = self.provider.stream_merge_chunks(result, chunk)
 
             turn = self.provider.stream_turn(
@@ -2675,6 +2712,29 @@ def emit(text: str | Content):
             tokens_log(self.provider, turn.tokens)
 
         self._turns.extend([user_turn, turn])
 
+    @overload
+    def _submit_turns_async(
+        self,
+        user_turn: UserTurn,
+        echo: EchoOptions,
+        stream: bool,
+        data_model: type[BaseModel] | None = None,
+        kwargs: Optional[SubmitInputArgsT] = None,
+        content_mode: Literal["text"] = "text",
+    ) -> AsyncGenerator[str, None]: ...
+
+    @overload
+    def _submit_turns_async(
+        self,
+        user_turn: UserTurn,
+        echo: EchoOptions,
+        stream: bool,
+        data_model: type[BaseModel] | None = None,
+        kwargs: Optional[SubmitInputArgsT] = None,
+        *,
+        content_mode: Literal["all"],
+    ) -> AsyncGenerator[str | Content, None]: ...
+
     async def _submit_turns_async(
         self,
         user_turn: UserTurn,
@@ -2682,7 +2742,8 @@ async def _submit_turns_async(
         stream: bool,
         data_model: type[BaseModel] | None = None,
         kwargs: Optional[SubmitInputArgsT] = None,
-    ) -> AsyncGenerator[str, None]:
+        content_mode: Literal["text", "all"] = "text",
+    ) -> AsyncGenerator[str | Content, None]:
         def emit(text: str | Content):
             self._echo_content(str(text))
 
@@ -2705,10 +2766,17 @@ def emit(text: str | Content):
             result = None
             async for chunk in response:
-                text = self.provider.stream_text(chunk)
-                if text:
-                    emit(text)
-                    yield text
+                content = self.provider.stream_content(chunk)
+                if content is not None:
+                    text = content_text(content)
+                    if text:
+                        emit(text)
+                        if content_mode == "all" and isinstance(
+                            content, ContentThinking
+                        ):
+                            yield content
+                        else:
+                            yield text
                 result = self.provider.stream_merge_chunks(result, chunk)
 
             turn = self.provider.stream_turn(
@@ -3184,6 +3252,15 @@ class ToolFailureWarning(RuntimeWarning):
     warnings.simplefilter("always", ToolFailureWarning)
 
 
+def content_text(content: Content) -> str:
+    """Extract displayable text from a Content object."""
+    if isinstance(content, ContentThinking):
+        return content.thinking
+    if isinstance(content, ContentText):
+        return content.text
+    return str(content)
+
+
 def is_quarto():
     return os.getenv("QUARTO_PYTHON", None) is not None
diff --git a/chatlas/_provider.py b/chatlas/_provider.py
index 5962e77c..b3e8d09a 100644
--- a/chatlas/_provider.py
+++ b/chatlas/_provider.py
@@ -15,7 +15,7 @@
 
 from pydantic import BaseModel
 
-from ._content import Content
+from ._content import Content, ContentText, ContentThinking
 from ._tools import Tool, ToolBuiltIn
 from ._turn import AssistantTurn, Turn
 from ._typing_extensions import NotRequired, TypedDict
@@ -226,7 +226,17 @@ async def chat_perform_async(
     ) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...
 
     @abstractmethod
-    def stream_text(self, chunk: ChatCompletionChunkT) -> Optional[str]: ...
+    def stream_content(self, chunk: ChatCompletionChunkT) -> Optional["Content"]: ...
+
+    def stream_text(self, chunk: ChatCompletionChunkT) -> Optional[str]:
+        content = self.stream_content(chunk)
+        if content is None:
+            return None
+        if isinstance(content, ContentThinking):
+            return content.thinking
+        if isinstance(content, ContentText):
+            return content.text
+        return str(content)
 
     @abstractmethod
     def stream_merge_chunks(
diff --git a/chatlas/_provider_anthropic.py b/chatlas/_provider_anthropic.py
index fca0c871..4a0d746d 100644
--- a/chatlas/_provider_anthropic.py
+++ b/chatlas/_provider_anthropic.py
@@ -463,12 +463,12 @@ def _structured_tool_call(**kwargs: Any):
 
         return kwargs_full
 
-    def stream_text(self, chunk) -> Optional[str]:
+    def stream_content(self, chunk) -> Optional[Content]:
         if chunk.type == "content_block_delta":
             if chunk.delta.type == "text_delta":
-                return chunk.delta.text
+                return ContentText.model_construct(text=chunk.delta.text)
             if chunk.delta.type == "thinking_delta":
-                return chunk.delta.thinking
+                return ContentThinking(thinking=chunk.delta.thinking)
         return None
 
     def stream_merge_chunks(self, completion, chunk):
@@ -830,9 +830,7 @@ def _as_turn(self, completion: Message, has_data_model=False) -> AssistantTurn:
                 extra = {
                     "type": content.type,
                     "tool_use_id": content.tool_use_id,
-                    "content": [
-                        x.model_dump() for x in content.content
-                    ]
+                    "content": [x.model_dump() for x in content.content]
                     if isinstance(content.content, list)
                     else content.content.model_dump(),
                 }
diff --git a/chatlas/_provider_google.py b/chatlas/_provider_google.py
index df5492b0..f512071e 100644
--- a/chatlas/_provider_google.py
+++ b/chatlas/_provider_google.py
@@ -14,6 +14,7 @@
     ContentJson,
     ContentPDF,
     ContentText,
+    ContentThinking,
     ContentToolRequest,
     ContentToolResult,
 )
@@ -361,12 +362,23 @@ def _chat_perform_args(
 
         return kwargs_full
 
-    def stream_text(self, chunk) -> Optional[str]:
-        try:
-            # Errors if there is no text (e.g., tool request)
-            return chunk.text
-        except Exception:
+    def stream_content(self, chunk) -> Optional[Content]:
+        candidates = getattr(chunk, "candidates", None)
+        if not candidates:
+            return None
+        content = candidates[0].content
+        if content is None:
+            return None
+        parts = content.parts
+        if not parts:
+            return None
+        part = parts[0]
+        text = getattr(part, "text", None)
+        if text is None:
             return None
+        if getattr(part, "thought", None):
+            return ContentThinking(thinking=text)
+        return ContentText.model_construct(text=text)
 
     def stream_merge_chunks(self, completion, chunk):
         chunkd = chunk.model_dump()
@@ -554,6 +566,8 @@ def _as_turn(
             if text:
                 if has_data_model:
                     contents.append(ContentJson(value=orjson.loads(text)))
+                elif part.get("thought"):
+                    contents.append(ContentThinking(thinking=text))
                 else:
                     contents.append(ContentText(text=text))
             function_call = part.get("function_call")
diff --git a/chatlas/_provider_openai.py b/chatlas/_provider_openai.py
index 6623bd9a..2384d837 100644
--- a/chatlas/_provider_openai.py
+++ b/chatlas/_provider_openai.py
@@ -292,16 +292,17 @@ def _chat_perform_args(
 
         return kwargs_full
 
-    def stream_text(self, chunk):
+    def stream_content(self, chunk) -> Optional[Content]:
         if chunk.type == "response.output_text.delta":
             # https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text/delta
-            return chunk.delta
+            return ContentText.model_construct(text=chunk.delta)
         if chunk.type == "response.reasoning_summary_text.delta":
             # https://platform.openai.com/docs/api-reference/responses-streaming/response/reasoning_summary_text/delta
-            return chunk.delta
+            return ContentThinking(thinking=chunk.delta)
         if chunk.type == "response.reasoning_summary_text.done":
+            # Separator between reasoning summary and response text
             # https://platform.openai.com/docs/api-reference/responses-streaming/response/reasoning_summary_text/done
-            return "\n\n"
+            return ContentText.model_construct(text="\n\n")
         return None
 
     def stream_merge_chunks(self, completion, chunk):
diff --git a/chatlas/_provider_openai_completions.py b/chatlas/_provider_openai_completions.py
index 1b7c66dc..064042a2 100644
--- a/chatlas/_provider_openai_completions.py
+++ b/chatlas/_provider_openai_completions.py
@@ -192,10 +192,13 @@ def _chat_perform_args(
 
         return kwargs_full
 
-    def stream_text(self, chunk):
+    def stream_content(self, chunk) -> Optional[Content]:
         if not chunk.choices:
             return None
-        return chunk.choices[0].delta.content
+        text = chunk.choices[0].delta.content
+        if text is None:
+            return None
+        return ContentText.model_construct(text=text)
 
     def stream_merge_chunks(self, completion, chunk):
         chunkd = chunk.model_dump()
diff --git a/chatlas/_provider_snowflake.py b/chatlas/_provider_snowflake.py
index 2f2c1d58..cc22dc5a 100644
--- a/chatlas/_provider_snowflake.py
+++ b/chatlas/_provider_snowflake.py
@@ -356,13 +356,13 @@ def _complete_request(
 
         return req
 
-    def stream_text(self, chunk):
+    def stream_content(self, chunk) -> Optional[Content]:
         if not chunk.choices:
             return None
         delta = chunk.choices[0].delta
         if delta is None or "content" not in delta:
             return None
-        return delta["content"]
+        return ContentText.model_construct(text=delta["content"])
 
     # Snowflake sort-of follows OpenAI/Anthropic streaming formats except they
     # don't have the critical "index" field in the delta that the merge logic
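Below is a minimal consumer-side sketch of the behavior this patch introduces; it is not part of the patch itself. It assumes `ContentThinking` is importable from `chatlas.types` (where the other `Content` classes are exposed), that the chosen model actually emits thinking/reasoning deltas, and the model name is a placeholder.

```python
# Sketch only: import path, provider choice, and model name are assumptions.
from chatlas import ChatAnthropic
from chatlas.types import ContentThinking

chat = ChatAnthropic(model="claude-sonnet-4-0")  # placeholder thinking-capable model

# With content="all", thinking deltas arrive as ContentThinking objects,
# while ordinary text deltas are still plain strings (and tool requests/results
# arrive as ContentToolRequest / ContentToolResult, as before).
for chunk in chat.stream("What is 15% of 80?", content="all"):
    if isinstance(chunk, ContentThinking):
        # Reasoning content: render separately, e.g. in a collapsible UI element
        print(f"[thinking] {chunk.thinking}", end="")
    else:
        print(str(chunk), end="")
```

The default `content="text"` path is unchanged: providers that only emit text (e.g., the Snowflake and OpenAI completions backends above) never produce `ContentThinking` chunks, and the new concrete `stream_text()` on the base class keeps the old string-only interface working for any code that still calls it.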