From cb5143615aa9ebc450418ce6a3dcd3c0f2f55ffd Mon Sep 17 00:00:00 2001 From: cbornet Date: Sat, 6 Dec 2025 17:49:21 +0100 Subject: [PATCH] chore(core): improve typing of messages utils functions --- libs/core/langchain_core/messages/utils.py | 61 +++++++++++++------ .../tests/unit_tests/messages/test_utils.py | 29 ++++++--- .../agents/middleware/summarization.py | 19 +++--- libs/langchain_v1/uv.lock | 4 +- 4 files changed, 77 insertions(+), 36 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 035cc0f6bf7ad..95f9221c53468 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -15,12 +15,16 @@ import logging import math from collections.abc import Callable, Iterable, Sequence -from functools import partial +from functools import partial, wraps from typing import ( TYPE_CHECKING, Annotated, Any, + Concatenate, Literal, + ParamSpec, + Protocol, + TypeVar, cast, overload, ) @@ -384,33 +388,54 @@ def convert_to_messages( return [_convert_to_message(m) for m in messages] -def _runnable_support(func: Callable) -> Callable: - @overload - def wrapped( - messages: None = None, **kwargs: Any - ) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ... +_P = ParamSpec("_P") +_R_co = TypeVar("_R_co", covariant=True) + +class _RunnableSupportCallable(Protocol[_P, _R_co]): @overload - def wrapped( - messages: Sequence[MessageLikeRepresentation], **kwargs: Any - ) -> list[BaseMessage]: ... + def __call__( + self, + messages: None = None, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ... + @overload + def __call__( + self, + messages: Sequence[MessageLikeRepresentation] | PromptValue, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _R_co: ... + + def __call__( + self, + messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ... + + +def _runnable_support( + func: Callable[ + Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co + ], +) -> _RunnableSupportCallable[_P, _R_co]: + @wraps(func) def wrapped( - messages: Sequence[MessageLikeRepresentation] | None = None, - **kwargs: Any, - ) -> ( - list[BaseMessage] - | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]] - ): + messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: # Import locally to prevent circular import. from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415 if messages is not None: - return func(messages, **kwargs) + return func(messages, *args, **kwargs) return RunnableLambda(partial(func, **kwargs), name=func.__name__) - wrapped.__doc__ = func.__doc__ - return wrapped + return cast("_RunnableSupportCallable[_P, _R_co]", wrapped) @_runnable_support diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 6876b23f5b874..0d526e4ded64f 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -2,10 +2,10 @@ import json import re from collections.abc import Callable, Sequence -from typing import Any +from typing import Any, TypedDict import pytest -from typing_extensions import override +from typing_extensions import NotRequired, override from langchain_core.language_models.fake_chat_models import FakeChatModel from langchain_core.messages import ( @@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None: assert messages == messages_model_copy +class FilterFields(TypedDict): + include_names: NotRequired[Sequence[str]] + exclude_names: NotRequired[Sequence[str]] + include_types: NotRequired[Sequence[str | type[BaseMessage]]] + exclude_types: NotRequired[Sequence[str | type[BaseMessage]]] + include_ids: NotRequired[Sequence[str]] + exclude_ids: NotRequired[Sequence[str]] + exclude_tool_calls: NotRequired[Sequence[str] | bool] + + @pytest.mark.parametrize( "filters", [ @@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None: {"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]}, ], ) -def test_filter_message(filters: dict) -> None: +def test_filter_message(filters: FilterFields) -> None: messages = [ SystemMessage("foo", name="blah", id="1"), HumanMessage("bar", name="blur", id="2"), @@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None: assert expected == actual # test explicitly excluding all tool calls - actual = filter_messages(messages, exclude_tool_calls={"1", "2"}) + actual = filter_messages(messages, exclude_tool_calls=["1", "2"]) assert expected == actual # test excluding a specific tool call @@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None: assert expected == actual # test explicitly excluding all tool calls - actual = filter_messages(messages, exclude_tool_calls={"1", "2"}) + actual = filter_messages(messages, exclude_tool_calls=["1", "2"]) assert expected == actual # test excluding a specific tool call @@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None: def test_trim_messages_bound_model_token_counter() -> None: trimmer = trim_messages( - max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar") + max_tokens=10, + token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload] ) trimmer.invoke([HumanMessage("foobar")]) def test_trim_messages_bad_token_counter() -> None: - trimmer = trim_messages(max_tokens=10, token_counter={}) + trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload] with pytest.raises( ValueError, match=re.escape( @@ -608,7 +619,9 @@ def count_text_length(msgs: list[BaseMessage]) -> int: assert len(result) == 1 assert len(result[0].content) == 1 - assert result[0].content[0]["text"] == "First part of text." + content = result[0].content[0] + assert isinstance(content, dict) + assert content["text"] == "First part of text." assert messages == messages_copy diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 6055c246a18d0..fd99ef16defee 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -516,14 +516,17 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMess try: if self.trim_tokens_to_summarize is None: return messages - return trim_messages( - messages, - max_tokens=self.trim_tokens_to_summarize, - token_counter=self.token_counter, - start_on="human", - strategy="last", - allow_partial=True, - include_system=True, + return cast( + "list[AnyMessage]", + trim_messages( + messages, + max_tokens=self.trim_tokens_to_summarize, + token_counter=self.token_counter, + start_on="human", + strategy="last", + allow_partial=True, + include_system=True, + ), ) except Exception: # noqa: BLE001 return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:] diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index 7696295b1997f..88f2545d6e64e 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -2166,7 +2166,7 @@ wheels = [ [[package]] name = "langchain-core" -version = "1.1.0" +version = "1.1.1" source = { editable = "../core" } dependencies = [ { name = "jsonpatch" },