# Commentary only: patch tools skip text that precedes the first file header,
# so nothing in this preamble becomes part of the patched files.
#
# Purpose
#   Forward `tool_call_id` as a keyword argument to the tool callbacks, and
#   add regression tests for
#   https://github.com/langchain-ai/langchain/issues/34168.
#
# libs/core/langchain_core/tools/base.py
#   * run(): adds `tool_call_id=tool_call_id` to a call whose opening line is
#     outside the first hunk (presumably the on_tool_start call, going by the
#     regression tests below -- confirm against the file), and to
#     `run_manager.on_tool_end(...)`.
#   * arun(): the same two additions, the second on the awaited
#     `run_manager.on_tool_end(...)`.
#   * The `run_manager.on_tool_error(error_to_raise)` call that appears as
#     context in run() is left unchanged by this patch.
#
# libs/core/tests/unit_tests/test_tools.py
#   * Imports AsyncCallbackHandler and BaseCallbackHandler from
#     langchain_core.callbacks.base.
#   * Adds ToolCallIdCaptureHandler and AsyncToolCallIdCaptureHandler. Their
#     on_tool_start / on_tool_end methods store the extra **kwargs they
#     receive and `kwargs.get("tool_call_id")`, which is None when the key
#     is absent.
#   * Adds a sync and an async test. Each invokes a `@tool` named get_weather
#     with a ToolCall whose id is "call_123" / "call_456", then asserts that
#     both callbacks recorded that id.
#
# NOTE(review): line breaks and indentation were restored from a
# whitespace-collapsed copy of this patch. Indentation of context lines and
# the position of blank context lines were inferred from the hunk line
# counts -- confirm against the target files before applying.
diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py
index 22f06f86dffcd..88e842c034e1d 100644
--- a/libs/core/langchain_core/tools/base.py
+++ b/libs/core/langchain_core/tools/base.py
@@ -877,6 +877,7 @@ def run(
             name=run_name,
             run_id=run_id,
             inputs=filtered_tool_input,
+            tool_call_id=tool_call_id,
             **kwargs,
         )
 
@@ -930,7 +931,9 @@ def run(
             run_manager.on_tool_error(error_to_raise)
             raise error_to_raise
         output = _format_output(content, artifact, tool_call_id, self.name, status)
-        run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
+        run_manager.on_tool_end(
+            output, color=color, name=self.name, tool_call_id=tool_call_id, **kwargs
+        )
         return output
 
     async def arun(
@@ -1004,6 +1007,7 @@ async def arun(
             name=run_name,
             run_id=run_id,
             inputs=filtered_tool_input,
+            tool_call_id=tool_call_id,
             **kwargs,
         )
         content = None
@@ -1060,7 +1064,9 @@ async def arun(
             raise error_to_raise
 
         output = _format_output(content, artifact, tool_call_id, self.name, status)
-        await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
+        await run_manager.on_tool_end(
+            output, color=color, name=self.name, tool_call_id=tool_call_id, **kwargs
+        )
         return output
 
 
diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py
index babb48175265f..99dd3f58732b8 100644
--- a/libs/core/tests/unit_tests/test_tools.py
+++ b/libs/core/tests/unit_tests/test_tools.py
@@ -30,6 +30,7 @@
     AsyncCallbackManagerForToolRun,
     CallbackManagerForToolRun,
 )
+from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 from langchain_core.callbacks.manager import (
     CallbackManagerForRetrieverRun,
 )
@@ -3127,3 +3128,147 @@ class MockRuntime:
     assert captured is not None
     assert captured == {"query": "test", "limit": 5}
     assert "runtime" not in captured
+
+
+class ToolCallIdCaptureHandler(BaseCallbackHandler):
+    """Callback handler that captures tool_call_id from callbacks."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.on_tool_start_tool_call_id: str | None = None
+        self.on_tool_start_kwargs: dict[str, Any] = {}
+        self.on_tool_end_tool_call_id: str | None = None
+        self.on_tool_end_kwargs: dict[str, Any] = {}
+
+    def on_tool_start(
+        self,
+        serialized: dict[str, Any],
+        input_str: str,
+        *,
+        run_id: Any,
+        parent_run_id: Any = None,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        self.on_tool_start_kwargs = kwargs
+        self.on_tool_start_tool_call_id = kwargs.get("tool_call_id")
+
+    def on_tool_end(
+        self,
+        output: Any,
+        *,
+        run_id: Any,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        self.on_tool_end_kwargs = kwargs
+        self.on_tool_end_tool_call_id = kwargs.get("tool_call_id")
+
+
+class AsyncToolCallIdCaptureHandler(AsyncCallbackHandler):
+    """Async callback handler that captures tool_call_id from callbacks."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.on_tool_start_tool_call_id: str | None = None
+        self.on_tool_start_kwargs: dict[str, Any] = {}
+        self.on_tool_end_tool_call_id: str | None = None
+        self.on_tool_end_kwargs: dict[str, Any] = {}
+
+    async def on_tool_start(
+        self,
+        serialized: dict[str, Any],
+        input_str: str,
+        *,
+        run_id: Any,
+        parent_run_id: Any = None,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        self.on_tool_start_kwargs = kwargs
+        self.on_tool_start_tool_call_id = kwargs.get("tool_call_id")
+
+    async def on_tool_end(
+        self,
+        output: Any,
+        *,
+        run_id: Any,
+        parent_run_id: Any = None,
+        **kwargs: Any,
+    ) -> None:
+        self.on_tool_end_kwargs = kwargs
+        self.on_tool_end_tool_call_id = kwargs.get("tool_call_id")
+
+
+def test_tool_call_id_passed_to_callbacks() -> None:
+    """Test that tool_call_id is passed to on_tool_start and on_tool_end callbacks.
+
+    Regression test for https://github.com/langchain-ai/langchain/issues/34168
+    """
+
+    @tool
+    def get_weather(location: str) -> str:
+        """Get weather for a location."""
+        return f"Sunny in {location}"
+
+    handler = ToolCallIdCaptureHandler()
+
+    # Invoke with a ToolCall that has an id
+    tool_call = ToolCall(
+        name="get_weather", args={"location": "Paris"}, id="call_123", type="tool_call"
+    )
+    result = get_weather.invoke(tool_call, config={"callbacks": [handler]})
+
+    # Verify the tool executed correctly
+    assert result.content == "Sunny in Paris"
+
+    # Verify tool_call_id was passed to on_tool_start
+    assert handler.on_tool_start_tool_call_id == "call_123", (
+        f"Expected tool_call_id='call_123' in on_tool_start, "
+        f"got kwargs={handler.on_tool_start_kwargs}"
+    )
+
+    # Verify tool_call_id was passed to on_tool_end
+    assert handler.on_tool_end_tool_call_id == "call_123", (
+        f"Expected tool_call_id='call_123' in on_tool_end, "
+        f"got kwargs={handler.on_tool_end_kwargs}"
+    )
+
+
+async def test_tool_call_id_passed_to_async_callbacks() -> None:
+    """Test that tool_call_id is passed to async on_tool_start and on_tool_end.
+
+    Regression test for https://github.com/langchain-ai/langchain/issues/34168
+    """
+
+    @tool
+    async def get_weather(location: str) -> str:
+        """Get weather for a location."""
+        return f"Sunny in {location}"
+
+    handler = AsyncToolCallIdCaptureHandler()
+
+    # Invoke with a ToolCall that has an id
+    tool_call = ToolCall(
+        name="get_weather", args={"location": "Paris"}, id="call_456", type="tool_call"
+    )
+    result = await get_weather.ainvoke(tool_call, config={"callbacks": [handler]})
+
+    # Verify the tool executed correctly
+    assert result.content == "Sunny in Paris"
+
+    # Verify tool_call_id was passed to on_tool_start
+    assert handler.on_tool_start_tool_call_id == "call_456", (
+        f"Expected tool_call_id='call_456' in on_tool_start, "
+        f"got kwargs={handler.on_tool_start_kwargs}"
+    )
+
+    # Verify tool_call_id was passed to on_tool_end
+    assert handler.on_tool_end_tool_call_id == "call_456", (
+        f"Expected tool_call_id='call_456' in on_tool_end, "
+        f"got kwargs={handler.on_tool_end_kwargs}"
+    )