Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
import logging
import math
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Concatenate,
Literal,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
Expand Down Expand Up @@ -384,33 +388,54 @@ def convert_to_messages(
return [_convert_to_message(m) for m in messages]


def _runnable_support(func: Callable) -> Callable:
@overload
def wrapped(
messages: None = None, **kwargs: Any
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
_P = ParamSpec("_P")
_R_co = TypeVar("_R_co", covariant=True)


class _RunnableSupportCallable(Protocol[_P, _R_co]):
@overload
def wrapped(
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
) -> list[BaseMessage]: ...
def __call__(
self,
messages: None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...

@overload
def __call__(
self,
messages: Sequence[MessageLikeRepresentation] | PromptValue,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co: ...

def __call__(
self,
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...


def _runnable_support(
func: Callable[
Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
],
) -> _RunnableSupportCallable[_P, _R_co]:
@wraps(func)
def wrapped(
messages: Sequence[MessageLikeRepresentation] | None = None,
**kwargs: Any,
) -> (
list[BaseMessage]
| Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
):
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
# Import locally to prevent circular import.
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415

if messages is not None:
return func(messages, **kwargs)
return func(messages, *args, **kwargs)
return RunnableLambda(partial(func, **kwargs), name=func.__name__)

wrapped.__doc__ = func.__doc__
return wrapped
return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)


@_runnable_support
Expand Down
29 changes: 21 additions & 8 deletions libs/core/tests/unit_tests/messages/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, TypedDict

import pytest
from typing_extensions import override
from typing_extensions import NotRequired, override

from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
Expand Down Expand Up @@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
assert messages == messages_model_copy


class FilterFields(TypedDict):
include_names: NotRequired[Sequence[str]]
exclude_names: NotRequired[Sequence[str]]
include_types: NotRequired[Sequence[str | type[BaseMessage]]]
exclude_types: NotRequired[Sequence[str | type[BaseMessage]]]
include_ids: NotRequired[Sequence[str]]
exclude_ids: NotRequired[Sequence[str]]
exclude_tool_calls: NotRequired[Sequence[str] | bool]


@pytest.mark.parametrize(
"filters",
[
Expand All @@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
],
)
def test_filter_message(filters: dict) -> None:
def test_filter_message(filters: FilterFields) -> None:
messages = [
SystemMessage("foo", name="blah", id="1"),
HumanMessage("bar", name="blur", id="2"),
Expand Down Expand Up @@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
assert expected == actual

# test explicitly excluding all tool calls
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
Copy link
Collaborator Author

@cbornet cbornet Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It found some incorrect typing 😃

assert expected == actual

# test excluding a specific tool call
Expand Down Expand Up @@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
assert expected == actual

# test explicitly excluding all tool calls
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
assert expected == actual

# test excluding a specific tool call
Expand Down Expand Up @@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:

def test_trim_messages_bound_model_token_counter() -> None:
trimmer = trim_messages(
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
max_tokens=10,
token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload]
)
trimmer.invoke([HumanMessage("foobar")])


def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={})
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is now typing for _runnable_support annotated method args.

with pytest.raises(
ValueError,
match=re.escape(
Expand Down Expand Up @@ -608,7 +619,9 @@ def count_text_length(msgs: list[BaseMessage]) -> int:

assert len(result) == 1
assert len(result[0].content) == 1
assert result[0].content[0]["text"] == "First part of text."
content = result[0].content[0]
assert isinstance(content, dict)
assert content["text"] == "First part of text."
assert messages == messages_copy


Expand Down
19 changes: 11 additions & 8 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,14 +516,17 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMess
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
return cast(
"list[AnyMessage]",
trim_messages(
messages,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
),
)
except Exception: # noqa: BLE001
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
4 changes: 2 additions & 2 deletions libs/langchain_v1/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.