22import json
33import re
44from collections .abc import Callable , Sequence
5- from typing import Any
5+ from typing import Any , TypedDict
66
77import pytest
8- from typing_extensions import override
8+ from typing_extensions import NotRequired , override
99
1010from langchain_core .language_models .fake_chat_models import FakeChatModel
1111from langchain_core .messages import (
@@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
135135 assert messages == messages_model_copy
136136
137137
138+ class FilterFields (TypedDict ):
139+ include_names : NotRequired [Sequence [str ]]
140+ exclude_names : NotRequired [Sequence [str ]]
141+ include_types : NotRequired [Sequence [str | type [BaseMessage ]]]
142+ exclude_types : NotRequired [Sequence [str | type [BaseMessage ]]]
143+ include_ids : NotRequired [Sequence [str ]]
144+ exclude_ids : NotRequired [Sequence [str ]]
145+ exclude_tool_calls : NotRequired [Sequence [str ] | bool ]
146+
147+
138148@pytest .mark .parametrize (
139149 "filters" ,
140150 [
@@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
153163 {"include_names" : ["blah" , "blur" ], "exclude_types" : [SystemMessage ]},
154164 ],
155165)
156- def test_filter_message (filters : dict ) -> None :
166+ def test_filter_message (filters : FilterFields ) -> None :
157167 messages = [
158168 SystemMessage ("foo" , name = "blah" , id = "1" ),
159169 HumanMessage ("bar" , name = "blur" , id = "2" ),
@@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
192202 assert expected == actual
193203
194204 # test explicitly excluding all tool calls
195- actual = filter_messages (messages , exclude_tool_calls = { "1" , "2" } )
205+ actual = filter_messages (messages , exclude_tool_calls = [ "1" , "2" ] )
196206 assert expected == actual
197207
198208 # test excluding a specific tool call
@@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
234244 assert expected == actual
235245
236246 # test explicitly excluding all tool calls
237- actual = filter_messages (messages , exclude_tool_calls = { "1" , "2" } )
247+ actual = filter_messages (messages , exclude_tool_calls = [ "1" , "2" ] )
238248 assert expected == actual
239249
240250 # test excluding a specific tool call
@@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:
508518
509519def test_trim_messages_bound_model_token_counter () -> None :
510520 trimmer = trim_messages (
511- max_tokens = 10 , token_counter = FakeTokenCountingModel ().bind (foo = "bar" )
521+ max_tokens = 10 ,
522+ token_counter = FakeTokenCountingModel ().bind (foo = "bar" ), # type: ignore[call-overload]
512523 )
513524 trimmer .invoke ([HumanMessage ("foobar" )])
514525
515526
516527def test_trim_messages_bad_token_counter () -> None :
517- trimmer = trim_messages (max_tokens = 10 , token_counter = {})
528+ trimmer = trim_messages (max_tokens = 10 , token_counter = {}) # type: ignore[call-overload]
518529 with pytest .raises (
519530 ValueError ,
520531 match = re .escape (
@@ -608,7 +619,9 @@ def count_text_length(msgs: list[BaseMessage]) -> int:
608619
609620 assert len (result ) == 1
610621 assert len (result [0 ].content ) == 1
611- assert result [0 ].content [0 ]["text" ] == "First part of text."
622+ content = result [0 ].content [0 ]
623+ assert isinstance (content , dict )
624+ assert content ["text" ] == "First part of text."
612625 assert messages == messages_copy
613626
614627
0 commit comments