|
20 | 20 | TYPE_CHECKING, |
21 | 21 | Annotated, |
22 | 22 | Any, |
| 23 | + Concatenate, |
23 | 24 | Literal, |
| 25 | + ParamSpec, |
| 26 | + TypeVar, |
24 | 27 | cast, |
25 | 28 | overload, |
26 | 29 | ) |
|
46 | 49 | if TYPE_CHECKING: |
47 | 50 | from langchain_core.language_models import BaseLanguageModel |
48 | 51 | from langchain_core.prompt_values import PromptValue |
49 | | - from langchain_core.runnables.base import Runnable |
| 52 | + from langchain_core.runnables.base import RunnableLambda |
50 | 53 |
|
51 | 54 | try: |
52 | 55 | from langchain_text_splitters import TextSplitter |
@@ -384,36 +387,64 @@ def convert_to_messages( |
384 | 387 | return [_convert_to_message(m) for m in messages] |
385 | 388 |
|
386 | 389 |
|
387 | | -def _runnable_support(func: Callable) -> Callable: |
388 | | - @overload |
389 | | - def wrapped( |
390 | | - messages: None = None, **kwargs: Any |
391 | | - ) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ... |
| 390 | +P = ParamSpec("P") # Parameters of the decorated function |
| 391 | +R = TypeVar("R") # Return type of the decorated function |
392 | 392 |
|
393 | | - @overload |
394 | | - def wrapped( |
395 | | - messages: Sequence[MessageLikeRepresentation], **kwargs: Any |
396 | | - ) -> list[BaseMessage]: ... |
397 | 393 |
|
| 394 | +def _runnable_support( |
| 395 | + func: Callable[Concatenate[Iterable[MessageLikeRepresentation], P], R], |
| 396 | +) -> Callable[ |
| 397 | + Concatenate[Iterable[MessageLikeRepresentation] | None, P], R | RunnableLambda |
| 398 | +]: |
| 399 | + # @wraps(func) |
398 | 400 | def wrapped( |
399 | | - messages: Sequence[MessageLikeRepresentation] | None = None, |
400 | | - **kwargs: Any, |
401 | | - ) -> ( |
402 | | - list[BaseMessage] |
403 | | - | Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]] |
404 | | - ): |
| 401 | + messages: Iterable[MessageLikeRepresentation] | None = None, |
| 402 | + /, |
| 403 | + *args: P.args, |
| 404 | + **kwargs: P.kwargs, |
| 405 | + ) -> R | RunnableLambda: |
405 | 406 | # Import locally to prevent circular import. |
406 | 407 | from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415 |
407 | 408 |
|
408 | 409 | if messages is not None: |
409 | | - return func(messages, **kwargs) |
| 410 | + return func(messages, *args, **kwargs) |
410 | 411 | return RunnableLambda(partial(func, **kwargs), name=func.__name__) |
411 | 412 |
|
412 | 413 | wrapped.__doc__ = func.__doc__ |
413 | 414 | return wrapped |
414 | 415 |
|
415 | 416 |
|
416 | | -@_runnable_support |
# Overload 1: called eagerly with messages -> returns the filtered list.
@overload
def filter_messages(
    messages: Iterable[MessageLikeRepresentation] | PromptValue,
    *,
    include_names: Sequence[str] | None = None,
    exclude_names: Sequence[str] | None = None,
    include_types: Sequence[str | type[BaseMessage]] | None = None,
    exclude_types: Sequence[str | type[BaseMessage]] | None = None,
    include_ids: Sequence[str] | None = None,
    exclude_ids: Sequence[str] | None = None,
    exclude_tool_calls: Sequence[str] | bool | None = None,
) -> list[BaseMessage]: ...


# Overload 2: called without messages -> returns a RunnableLambda that
# performs the same filtering when later invoked with messages (the curried
# form provided by ``_runnable_support``).
# NOTE(review): ``RunnableLambda`` is imported under TYPE_CHECKING only —
# confirm the module uses ``from __future__ import annotations`` (or string
# annotations) so this return annotation is not evaluated at runtime.
@overload
def filter_messages(
    messages: None = None,
    *,
    include_names: Sequence[str] | None = None,
    exclude_names: Sequence[str] | None = None,
    include_types: Sequence[str | type[BaseMessage]] | None = None,
    exclude_types: Sequence[str | type[BaseMessage]] | None = None,
    include_ids: Sequence[str] | None = None,
    exclude_ids: Sequence[str] | None = None,
    exclude_tool_calls: Sequence[str] | bool | None = None,
) -> RunnableLambda[
    Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
]: ...
| 445 | + |
| 446 | + |
| 447 | +@_runnable_support # type: ignore[misc] |
417 | 448 | def filter_messages( |
418 | 449 | messages: Iterable[MessageLikeRepresentation] | PromptValue, |
419 | 450 | *, |
@@ -557,7 +588,25 @@ def filter_messages( |
557 | 588 | return filtered |
558 | 589 |
|
559 | 590 |
|
560 | | -@_runnable_support |
# Overload 1: called eagerly with messages -> returns the merged list.
@overload
def merge_message_runs(
    messages: Iterable[MessageLikeRepresentation] | PromptValue,
    *,
    chunk_separator: str = "\n",
) -> list[BaseMessage]: ...


# Overload 2: called without messages -> returns a RunnableLambda that merges
# consecutive same-type messages when later invoked (curried form provided by
# ``_runnable_support``).
@overload
def merge_message_runs(
    messages: None = None,
    *,
    chunk_separator: str = "\n",
) -> RunnableLambda[
    Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
]: ...
| 608 | + |
| 609 | +@_runnable_support # type: ignore[misc] |
561 | 610 | def merge_message_runs( |
562 | 611 | messages: Iterable[MessageLikeRepresentation] | PromptValue, |
563 | 612 | *, |
@@ -686,9 +735,45 @@ def merge_message_runs( |
686 | 735 | return merged |
687 | 736 |
|
688 | 737 |
|
# Overload 1: called eagerly with messages -> returns the trimmed list.
@overload
def trim_messages(
    messages: Iterable[MessageLikeRepresentation] | PromptValue,
    *,
    max_tokens: int,
    token_counter: Callable[[list[BaseMessage]], int]
    | Callable[[BaseMessage], int]
    | BaseLanguageModel,
    strategy: Literal["first", "last"] = "last",
    allow_partial: bool = False,
    end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    include_system: bool = False,
    text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
) -> list[BaseMessage]: ...


# Overload 2: called without messages -> returns a RunnableLambda that trims
# when later invoked with messages (curried form provided by
# ``_runnable_support``). Note ``max_tokens`` and ``token_counter`` remain
# required keyword arguments in both forms.
@overload
def trim_messages(
    messages: None = None,
    *,
    max_tokens: int,
    token_counter: Callable[[list[BaseMessage]], int]
    | Callable[[BaseMessage], int]
    | BaseLanguageModel,
    strategy: Literal["first", "last"] = "last",
    allow_partial: bool = False,
    end_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    start_on: str | type[BaseMessage] | Sequence[str | type[BaseMessage]] | None = None,
    include_system: bool = False,
    text_splitter: Callable[[str], list[str]] | TextSplitter | None = None,
) -> RunnableLambda[
    Iterable[MessageLikeRepresentation] | PromptValue, list[BaseMessage]
]: ...
| 773 | + |
689 | 774 | # TODO: Update so validation errors (for token_counter, for example) are raised on |
690 | 775 | # init not at runtime. |
691 | | -@_runnable_support |
| 776 | +@_runnable_support # type: ignore[misc] |
692 | 777 | def trim_messages( |
693 | 778 | messages: Iterable[MessageLikeRepresentation] | PromptValue, |
694 | 779 | *, |
|
0 commit comments