Skip to content

Commit cb51436

Browse files
committed
chore(core): improve typing of messages utils functions
1 parent 78c10f8 commit cb51436

File tree

4 files changed

+77
-36
lines changed

4 files changed

+77
-36
lines changed

libs/core/langchain_core/messages/utils.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
import logging
1616
import math
1717
from collections.abc import Callable, Iterable, Sequence
18-
from functools import partial
18+
from functools import partial, wraps
1919
from typing import (
2020
TYPE_CHECKING,
2121
Annotated,
2222
Any,
23+
Concatenate,
2324
Literal,
25+
ParamSpec,
26+
Protocol,
27+
TypeVar,
2428
cast,
2529
overload,
2630
)
@@ -384,33 +388,54 @@ def convert_to_messages(
384388
return [_convert_to_message(m) for m in messages]
385389

386390

387-
def _runnable_support(func: Callable) -> Callable:
388-
@overload
389-
def wrapped(
390-
messages: None = None, **kwargs: Any
391-
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
391+
_P = ParamSpec("_P")
392+
_R_co = TypeVar("_R_co", covariant=True)
393+
392394

395+
class _RunnableSupportCallable(Protocol[_P, _R_co]):
393396
@overload
394-
def wrapped(
395-
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
396-
) -> list[BaseMessage]: ...
397+
def __call__(
398+
self,
399+
messages: None = None,
400+
*args: _P.args,
401+
**kwargs: _P.kwargs,
402+
) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
397403

404+
@overload
405+
def __call__(
406+
self,
407+
messages: Sequence[MessageLikeRepresentation] | PromptValue,
408+
*args: _P.args,
409+
**kwargs: _P.kwargs,
410+
) -> _R_co: ...
411+
412+
def __call__(
413+
self,
414+
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
415+
*args: _P.args,
416+
**kwargs: _P.kwargs,
417+
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
418+
419+
420+
def _runnable_support(
421+
func: Callable[
422+
Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
423+
],
424+
) -> _RunnableSupportCallable[_P, _R_co]:
425+
@wraps(func)
398426
def wrapped(
399-
messages: Sequence[MessageLikeRepresentation] | None = None,
400-
**kwargs: Any,
401-
) -> (
402-
list[BaseMessage]
403-
| Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
404-
):
427+
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
428+
*args: _P.args,
429+
**kwargs: _P.kwargs,
430+
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
405431
# Import locally to prevent circular import.
406432
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
407433

408434
if messages is not None:
409-
return func(messages, **kwargs)
435+
return func(messages, *args, **kwargs)
410436
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
411437

412-
wrapped.__doc__ = func.__doc__
413-
return wrapped
438+
return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)
414439

415440

416441
@_runnable_support

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import json
33
import re
44
from collections.abc import Callable, Sequence
5-
from typing import Any
5+
from typing import Any, TypedDict
66

77
import pytest
8-
from typing_extensions import override
8+
from typing_extensions import NotRequired, override
99

1010
from langchain_core.language_models.fake_chat_models import FakeChatModel
1111
from langchain_core.messages import (
@@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
135135
assert messages == messages_model_copy
136136

137137

138+
class FilterFields(TypedDict):
139+
include_names: NotRequired[Sequence[str]]
140+
exclude_names: NotRequired[Sequence[str]]
141+
include_types: NotRequired[Sequence[str | type[BaseMessage]]]
142+
exclude_types: NotRequired[Sequence[str | type[BaseMessage]]]
143+
include_ids: NotRequired[Sequence[str]]
144+
exclude_ids: NotRequired[Sequence[str]]
145+
exclude_tool_calls: NotRequired[Sequence[str] | bool]
146+
147+
138148
@pytest.mark.parametrize(
139149
"filters",
140150
[
@@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
153163
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
154164
],
155165
)
156-
def test_filter_message(filters: dict) -> None:
166+
def test_filter_message(filters: FilterFields) -> None:
157167
messages = [
158168
SystemMessage("foo", name="blah", id="1"),
159169
HumanMessage("bar", name="blur", id="2"),
@@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
192202
assert expected == actual
193203

194204
# test explicitly excluding all tool calls
195-
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
205+
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
196206
assert expected == actual
197207

198208
# test excluding a specific tool call
@@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
234244
assert expected == actual
235245

236246
# test explicitly excluding all tool calls
237-
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
247+
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
238248
assert expected == actual
239249

240250
# test excluding a specific tool call
@@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:
508518

509519
def test_trim_messages_bound_model_token_counter() -> None:
510520
trimmer = trim_messages(
511-
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
521+
max_tokens=10,
522+
token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload]
512523
)
513524
trimmer.invoke([HumanMessage("foobar")])
514525

515526

516527
def test_trim_messages_bad_token_counter() -> None:
517-
trimmer = trim_messages(max_tokens=10, token_counter={})
528+
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
518529
with pytest.raises(
519530
ValueError,
520531
match=re.escape(
@@ -608,7 +619,9 @@ def count_text_length(msgs: list[BaseMessage]) -> int:
608619

609620
assert len(result) == 1
610621
assert len(result[0].content) == 1
611-
assert result[0].content[0]["text"] == "First part of text."
622+
content = result[0].content[0]
623+
assert isinstance(content, dict)
624+
assert content["text"] == "First part of text."
612625
assert messages == messages_copy
613626

614627

libs/langchain_v1/langchain/agents/middleware/summarization.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -516,14 +516,17 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMess
516516
try:
517517
if self.trim_tokens_to_summarize is None:
518518
return messages
519-
return trim_messages(
520-
messages,
521-
max_tokens=self.trim_tokens_to_summarize,
522-
token_counter=self.token_counter,
523-
start_on="human",
524-
strategy="last",
525-
allow_partial=True,
526-
include_system=True,
519+
return cast(
520+
"list[AnyMessage]",
521+
trim_messages(
522+
messages,
523+
max_tokens=self.trim_tokens_to_summarize,
524+
token_counter=self.token_counter,
525+
start_on="human",
526+
strategy="last",
527+
allow_partial=True,
528+
include_system=True,
529+
),
527530
)
528531
except Exception: # noqa: BLE001
529532
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]

libs/langchain_v1/uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)