Skip to content

Commit d06f642

Browse files
committed
chore(core): improve typing of messages utils functions
1 parent 78c10f8 commit d06f642

File tree

2 files changed

+125
-27
lines changed

2 files changed

+125
-27
lines changed

libs/core/langchain_core/messages/utils.py

Lines changed: 105 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
TYPE_CHECKING,
2121
Annotated,
2222
Any,
23+
Concatenate,
2324
Literal,
25+
ParamSpec,
26+
TypeVar,
2427
cast,
2528
overload,
2629
)
@@ -46,7 +49,7 @@
4649
if TYPE_CHECKING:
4750
from langchain_core.language_models import BaseLanguageModel
4851
from langchain_core.prompt_values import PromptValue
49-
from langchain_core.runnables.base import Runnable
52+
from langchain_core.runnables.base import RunnableLambda
5053

5154
try:
5255
from langchain_text_splitters import TextSplitter
@@ -384,36 +387,64 @@ def convert_to_messages(
384387
return [_convert_to_message(m) for m in messages]
385388

386389

387-
def _runnable_support(func: Callable) -> Callable:
388-
@overload
389-
def wrapped(
390-
messages: None = None, **kwargs: Any
391-
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
390+
P = ParamSpec("P") # Parameters of the decorated function
391+
R = TypeVar("R") # Return type of the decorated function
392392

393-
@overload
394-
def wrapped(
395-
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
396-
) -> list[BaseMessage]: ...
397393

394+
def _runnable_support(
395+
func: Callable[Concatenate[Iterable[MessageLikeRepresentation], P], R],
396+
) -> Callable[
397+
Concatenate[Iterable[MessageLikeRepresentation] | None, P], R | RunnableLambda
398+
]:
399+
# @wraps(func)
398400
def wrapped(
399-
messages: Sequence[MessageLikeRepresentation] | None = None,
400-
**kwargs: Any,
401-
) -> (
402-
list[BaseMessage]
403-
| Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
404-
):
401+
messages: Iterable[MessageLikeRepresentation] | None = None,
402+
/,
403+
*args: P.args,
404+
**kwargs: P.kwargs,
405+
) -> R | RunnableLambda:
405406
# Import locally to prevent circular import.
406407
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
407408

408409
if messages is not None:
409-
return func(messages, **kwargs)
410+
return func(messages, *args, **kwargs)
410411
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
411412

412413
wrapped.__doc__ = func.__doc__
413414
return wrapped
414415

415416

416-
@_runnable_support
417+
@overload
418+
def filter_messages(
419+
messages: Iterable[MessageLikeRepresentation] | PromptValue,
420+
*,
421+
include_names: Sequence[str] | None = None,
422+
exclude_names: Sequence[str] | None = None,
423+
include_types: Sequence[str | type[BaseMessage]] | None = None,
424+
exclude_types: Sequence[str | type[BaseMessage]] | None = None,
425+
include_ids: Sequence[str] | None = None,
426+
exclude_ids: Sequence[str] | None = None,
427+
exclude_tool_calls: Sequence[str] | bool | None = None,
428+
) -> list[BaseMessage]: ...
429+
430+
431+
@overload
432+
def filter_messages(
433+
messages: None = None,
434+
*,
435+
include_names: Sequence[str] | None = None,
436+
exclude_names: Sequence[str] | None = None,
437+
include_types: Sequence[str | type[BaseMessage]] | None = None,
438+
exclude_types: Sequence[str | type[BaseMessage]] | None = None,
439+
include_ids: Sequence[str] | None = None,
440+
exclude_ids: Sequence[str] | None = None,
441+
exclude_tool_calls: Sequence[str] | bool | None = None,
442+
) -> RunnableLambda[
443+
Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
444+
]: ...
445+
446+
447+
@_runnable_support # type: ignore[misc]
417448
def filter_messages(
418449
messages: Iterable[MessageLikeRepresentation] | PromptValue,
419450
*,
@@ -557,7 +588,25 @@ def filter_messages(
557588
return filtered
558589

559590

560-
@_runnable_support
591+
@overload
592+
def merge_message_runs(
593+
messages: Iterable[MessageLikeRepresentation] | PromptValue,
594+
*,
595+
chunk_separator: str = "\n",
596+
) -> list[BaseMessage]: ...
597+
598+
599+
@overload
600+
def merge_message_runs(
601+
messages: None = None,
602+
*,
603+
chunk_separator: str = "\n",
604+
) -> RunnableLambda[
605+
Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
606+
]: ...
607+
608+
609+
@_runnable_support # type: ignore[misc]
561610
def merge_message_runs(
562611
messages: Iterable[MessageLikeRepresentation] | PromptValue,
563612
*,
@@ -686,9 +735,45 @@ def merge_message_runs(
686735
return merged
687736

688737

738+
@overload
739+
def trim_messages(
740+
messages: Iterable[MessageLikeRepresentation] | PromptValue,
741+
*,
742+
max_tokens: int,
743+
token_counter: Callable[[list[BaseMessage]], int]
744+
| Callable[[BaseMessage], int]
745+
| BaseLanguageModel,
746+
strategy: Literal["first", "last"] = "last",
747+
allow_partial: bool = False,
748+
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
749+
start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
750+
include_system: bool = False,
751+
text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
752+
) -> list[BaseMessage]: ...
753+
754+
755+
@overload
756+
def trim_messages(
757+
messages: None = None,
758+
*,
759+
max_tokens: int,
760+
token_counter: Callable[[list[BaseMessage]], int]
761+
| Callable[[BaseMessage], int]
762+
| BaseLanguageModel,
763+
strategy: Literal["first", "last"] = "last",
764+
allow_partial: bool = False,
765+
end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
766+
start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
767+
include_system: bool = False,
768+
text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
769+
) -> RunnableLambda[
770+
Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
771+
]: ...
772+
773+
689774
# TODO: Update so validation errors (for token_counter, for example) are raised on
690775
# init not at runtime.
691-
@_runnable_support
776+
@_runnable_support # type: ignore[misc]
692777
def trim_messages(
693778
messages: Iterable[MessageLikeRepresentation] | PromptValue,
694779
*,

libs/core/tests/unit_tests/messages/test_utils.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import re
44
from collections.abc import Callable, Sequence
5-
from typing import Any
5+
from typing import Any, TypedDict
66

77
import pytest
88
from typing_extensions import override
@@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
135135
assert messages == messages_model_copy
136136

137137

138+
class FilterFields(TypedDict):
139+
include_names: Sequence[str] | None
140+
exclude_names: Sequence[str] | None
141+
include_types: Sequence[str | type[BaseMessage]] | None
142+
exclude_types: Sequence[str | type[BaseMessage]] | None
143+
include_ids: Sequence[str] | None
144+
exclude_ids: Sequence[str] | None
145+
exclude_tool_calls: Sequence[str] | bool | None
146+
147+
138148
@pytest.mark.parametrize(
139149
"filters",
140150
[
@@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
153163
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
154164
],
155165
)
156-
def test_filter_message(filters: dict) -> None:
166+
def test_filter_message(filters: FilterFields) -> None:
157167
messages = [
158168
SystemMessage("foo", name="blah", id="1"),
159169
HumanMessage("bar", name="blur", id="2"),
@@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
192202
assert expected == actual
193203

194204
# test explicitly excluding all tool calls
195-
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
205+
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
196206
assert expected == actual
197207

198208
# test excluding a specific tool call
@@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
234244
assert expected == actual
235245

236246
# test explicitly excluding all tool calls
237-
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
247+
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
238248
assert expected == actual
239249

240250
# test excluding a specific tool call
@@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:
508518

509519
def test_trim_messages_bound_model_token_counter() -> None:
510520
trimmer = trim_messages(
511-
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
521+
max_tokens=10,
522+
token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload]
512523
)
513524
trimmer.invoke([HumanMessage("foobar")])
514525

515526

516527
def test_trim_messages_bad_token_counter() -> None:
517-
trimmer = trim_messages(max_tokens=10, token_counter={})
528+
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
518529
with pytest.raises(
519530
ValueError,
520531
match=re.escape(
@@ -608,7 +619,9 @@ def count_text_length(msgs: list[BaseMessage]) -> int:
608619

609620
assert len(result) == 1
610621
assert len(result[0].content) == 1
611-
assert result[0].content[0]["text"] == "First part of text."
622+
content = result[0].content[0]
623+
assert isinstance(content, dict)
624+
assert content["text"] == "First part of text."
612625
assert messages == messages_copy
613626

614627

0 commit comments

Comments
 (0)