Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,8 @@ logs/

/.luarc.json

# worktrees
.worktrees/

# setuptools_scm
chatlas/_version.py
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [UNRELEASED]

### New features

* The `.stream()` and `.stream_async()` methods now yield `ContentThinking` objects (instead of plain strings) for thinking/reasoning content when `content="all"`. This allows downstream packages like shinychat to provide specific UI for thinking content. (#276)

### Bug fixes

* Fixed tool calling with Google thinking models (e.g., `gemini-3-flash-preview`) failing with a 400 `INVALID_ARGUMENT` error about a missing `thought_signature`. The signature is now preserved and forwarded in subsequent turns. (#274)
Expand Down
123 changes: 100 additions & 23 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Content,
ContentJson,
ContentText,
ContentThinking,
ContentToolRequest,
ContentToolResult,
ToolInfo,
Expand Down Expand Up @@ -1155,7 +1156,7 @@ def stream(
echo: EchoOptions = "none",
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]: ...

def stream(
self,
Expand All @@ -1164,7 +1165,7 @@ def stream(
echo: EchoOptions = "none",
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]:
"""
Generate a response from the chat in a streaming fashion.

Expand Down Expand Up @@ -1228,7 +1229,7 @@ class Person(BaseModel):
)

def wrapper() -> Generator[
str | ContentToolRequest | ContentToolResult, None, None
str | ContentThinking | ContentToolRequest | ContentToolResult, None, None
]:
with display:
for chunk in generator:
Expand All @@ -1254,7 +1255,7 @@ async def stream_async(
echo: EchoOptions = "none",
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]: ...

async def stream_async(
self,
Expand All @@ -1263,7 +1264,7 @@ async def stream_async(
echo: EchoOptions = "none",
data_model: Optional[type[BaseModel]] = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]:
"""
Generate a response from the chat in a streaming fashion asynchronously.

Expand Down Expand Up @@ -1309,9 +1310,12 @@ class Person(BaseModel):


chat = ChatOpenAI()
chunks = [chunk async for chunk in await chat.stream_async(
"John is 25 years old", data_model=Person
)]
chunks = [
chunk
async for chunk in await chat.stream_async(
"John is 25 years old", data_model=Person
)
]
person = Person.model_validate_json("".join(chunks))
```
"""
Expand All @@ -1320,7 +1324,7 @@ class Person(BaseModel):
display = self._markdown_display(echo=echo)

async def wrapper() -> AsyncGenerator[
str | ContentToolRequest | ContentToolResult, None
str | ContentThinking | ContentToolRequest | ContentToolResult, None
]:
with display:
async for chunk in self._chat_impl_async(
Expand Down Expand Up @@ -2481,7 +2485,7 @@ def _chat_impl(
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
data_model: Optional[type[BaseModel]] = None,
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]: ...
) -> Generator[str | ContentThinking | ContentToolRequest | ContentToolResult, None, None]: ...

def _chat_impl(
self,
Expand All @@ -2491,7 +2495,7 @@ def _chat_impl(
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
data_model: Optional[type[BaseModel]] = None,
) -> Generator[str | ContentToolRequest | ContentToolResult, None, None]:
) -> Generator[str | Content, None, None]:
user_turn_result: UserTurn | None = user_turn
while user_turn_result is not None:
for chunk in self._submit_turns(
Expand All @@ -2500,6 +2504,7 @@ def _chat_impl(
stream=stream,
data_model=data_model,
kwargs=kwargs,
content_mode=content,
):
yield chunk

Expand Down Expand Up @@ -2548,7 +2553,7 @@ def _chat_impl_async(
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
data_model: Optional[type[BaseModel]] = None,
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]: ...
) -> AsyncGenerator[str | ContentThinking | ContentToolRequest | ContentToolResult, None]: ...

async def _chat_impl_async(
self,
Expand All @@ -2558,7 +2563,7 @@ async def _chat_impl_async(
stream: bool,
kwargs: Optional[SubmitInputArgsT] = None,
data_model: Optional[type[BaseModel]] = None,
) -> AsyncGenerator[str | ContentToolRequest | ContentToolResult, None]:
) -> AsyncGenerator[str | Content, None]:
user_turn_result: UserTurn | None = user_turn
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
Expand All @@ -2567,6 +2572,7 @@ async def _chat_impl_async(
stream=stream,
data_model=data_model,
kwargs=kwargs,
content_mode=content,
):
yield chunk

Expand Down Expand Up @@ -2597,14 +2603,38 @@ async def _chat_impl_async(
if all_results:
user_turn_result = UserTurn(all_results)

# Overload: with the default content_mode="text", the generator yields only
# plain-text str chunks; tool requests/results are consumed internally.
@overload
def _submit_turns(
    self,
    user_turn: UserTurn,
    echo: EchoOptions,
    stream: bool,
    data_model: type[BaseModel] | None = None,
    kwargs: Optional[SubmitInputArgsT] = None,
    content_mode: Literal["text"] = "text",
) -> Generator[str, None, None]: ...

# Overload: with content_mode="all" (keyword-only so the overloads stay
# disjoint for the type checker), the generator may also yield Content
# objects (e.g. ContentThinking) in addition to plain text.
@overload
def _submit_turns(
    self,
    user_turn: UserTurn,
    echo: EchoOptions,
    stream: bool,
    data_model: type[BaseModel] | None = None,
    kwargs: Optional[SubmitInputArgsT] = None,
    *,
    content_mode: Literal["all"],
) -> Generator[str | Content, None, None]: ...

def _submit_turns(
self,
user_turn: UserTurn,
echo: EchoOptions,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
content_mode: Literal["text", "all"] = "text",
) -> Generator[str | Content, None, None]:
if any(isinstance(x, Tool) and x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

Expand All @@ -2630,10 +2660,17 @@ def emit(text: str | Content):

result = None
for chunk in response:
text = self.provider.stream_text(chunk)
if text:
emit(text)
yield text
content = self.provider.stream_content(chunk)
if content is not None:
text = content_text(content)
if text:
emit(text)
if content_mode == "all" and isinstance(
content, ContentThinking
):
yield content
Comment thread
cpsievert marked this conversation as resolved.
else:
yield text
result = self.provider.stream_merge_chunks(result, chunk)

turn = self.provider.stream_turn(
Expand Down Expand Up @@ -2675,14 +2712,38 @@ def emit(text: str | Content):
tokens_log(self.provider, turn.tokens)
self._turns.extend([user_turn, turn])

# Overload: with the default content_mode="text", the async generator yields
# only plain-text str chunks; tool requests/results are consumed internally.
@overload
def _submit_turns_async(
    self,
    user_turn: UserTurn,
    echo: EchoOptions,
    stream: bool,
    data_model: type[BaseModel] | None = None,
    kwargs: Optional[SubmitInputArgsT] = None,
    content_mode: Literal["text"] = "text",
) -> AsyncGenerator[str, None]: ...

# Overload: with content_mode="all" (keyword-only so the overloads stay
# disjoint for the type checker), the async generator may also yield Content
# objects (e.g. ContentThinking) in addition to plain text.
@overload
def _submit_turns_async(
    self,
    user_turn: UserTurn,
    echo: EchoOptions,
    stream: bool,
    data_model: type[BaseModel] | None = None,
    kwargs: Optional[SubmitInputArgsT] = None,
    *,
    content_mode: Literal["all"],
) -> AsyncGenerator[str | Content, None]: ...

async def _submit_turns_async(
self,
user_turn: UserTurn,
echo: EchoOptions,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
content_mode: Literal["text", "all"] = "text",
) -> AsyncGenerator[str | Content, None]:
def emit(text: str | Content):
self._echo_content(str(text))

Expand All @@ -2705,10 +2766,17 @@ def emit(text: str | Content):

result = None
async for chunk in response:
text = self.provider.stream_text(chunk)
if text:
emit(text)
yield text
content = self.provider.stream_content(chunk)
if content is not None:
text = content_text(content)
if text:
emit(text)
if content_mode == "all" and isinstance(
content, ContentThinking
):
yield content
Comment thread
cpsievert marked this conversation as resolved.
else:
yield text
result = self.provider.stream_merge_chunks(result, chunk)

turn = self.provider.stream_turn(
Expand Down Expand Up @@ -3184,6 +3252,15 @@ class ToolFailureWarning(RuntimeWarning):
warnings.simplefilter("always", ToolFailureWarning)


def content_text(content: Content) -> str:
    """Return the displayable text carried by a Content object.

    ContentThinking stores its text under ``.thinking`` and ContentText
    under ``.text``; any other Content subclass falls back to ``str()``.
    """
    for klass, attr in ((ContentThinking, "thinking"), (ContentText, "text")):
        if isinstance(content, klass):
            return getattr(content, attr)
    return str(content)


def is_quarto():
return os.getenv("QUARTO_PYTHON", None) is not None

Expand Down
14 changes: 12 additions & 2 deletions chatlas/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from pydantic import BaseModel

from ._content import Content
from ._content import Content, ContentText, ContentThinking
from ._tools import Tool, ToolBuiltIn
from ._turn import AssistantTurn, Turn
from ._typing_extensions import NotRequired, TypedDict
Expand Down Expand Up @@ -226,7 +226,17 @@ async def chat_perform_async(
) -> AsyncIterable[ChatCompletionChunkT] | ChatCompletionT: ...

@abstractmethod
def stream_content(self, chunk: ChatCompletionChunkT) -> Optional["Content"]: ...

def stream_text(self, chunk: ChatCompletionChunkT) -> Optional[str]:
    """Plain-text view over :meth:`stream_content`.

    Providers implement ``stream_content()``; this default derives the text
    representation from the returned Content: ``.thinking`` for
    ContentThinking, ``.text`` for ContentText, ``str()`` otherwise, and
    ``None`` when the chunk carries no content at all.
    """
    piece = self.stream_content(chunk)
    if piece is None:
        return None
    for klass, attr in ((ContentThinking, "thinking"), (ContentText, "text")):
        if isinstance(piece, klass):
            return getattr(piece, attr)
    return str(piece)

@abstractmethod
def stream_merge_chunks(
Expand Down
10 changes: 4 additions & 6 deletions chatlas/_provider_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,12 @@ def _structured_tool_call(**kwargs: Any):

return kwargs_full

def stream_content(self, chunk) -> Optional[Content]:
    """Map a raw Anthropic stream event onto a Content object.

    Only ``content_block_delta`` events carry streamable text; text deltas
    become ContentText and thinking deltas become ContentThinking. All
    other event/delta types yield ``None``.
    """
    if chunk.type != "content_block_delta":
        return None
    delta = chunk.delta
    if delta.type == "text_delta":
        return ContentText.model_construct(text=delta.text)
    if delta.type == "thinking_delta":
        return ContentThinking(thinking=delta.thinking)
    return None

def stream_merge_chunks(self, completion, chunk):
Expand Down Expand Up @@ -830,9 +830,7 @@ def _as_turn(self, completion: Message, has_data_model=False) -> AssistantTurn:
extra = {
"type": content.type,
"tool_use_id": content.tool_use_id,
"content": [
x.model_dump() for x in content.content
]
"content": [x.model_dump() for x in content.content]
if isinstance(content.content, list)
else content.content.model_dump(),
}
Expand Down
24 changes: 19 additions & 5 deletions chatlas/_provider_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ContentJson,
ContentPDF,
ContentText,
ContentThinking,
ContentToolRequest,
ContentToolResult,
)
Expand Down Expand Up @@ -361,12 +362,23 @@ def _chat_perform_args(

return kwargs_full

def stream_content(self, chunk) -> Optional[Content]:
    """Convert a streamed Gemini chunk into a Content object.

    Returns ``None`` when the chunk has no candidates, no parts, or no
    text (e.g., a tool request). A part flagged as ``thought`` becomes
    ContentThinking; ordinary text becomes ContentText.
    """
    candidates = getattr(chunk, "candidates", None)
    if not candidates:
        return None
    body = candidates[0].content
    if body is None or not body.parts:
        return None
    # NOTE(review): only the first part is inspected; confirm that streamed
    # chunks never spread text across multiple parts (the old ``chunk.text``
    # accessor may have aggregated them).
    first = body.parts[0]
    text = getattr(first, "text", None)
    if text is None:
        return None
    if getattr(first, "thought", None):
        return ContentThinking(thinking=text)
    return ContentText.model_construct(text=text)

def stream_merge_chunks(self, completion, chunk):
chunkd = chunk.model_dump()
Expand Down Expand Up @@ -554,6 +566,8 @@ def _as_turn(
if text:
if has_data_model:
contents.append(ContentJson(value=orjson.loads(text)))
elif part.get("thought"):
contents.append(ContentThinking(thinking=text))
else:
contents.append(ContentText(text=text))
function_call = part.get("function_call")
Expand Down
Loading
Loading