Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ def run(
name=run_name,
run_id=run_id,
inputs=filtered_tool_input,
tool_call_id=tool_call_id,
**kwargs,
)

Expand Down Expand Up @@ -930,7 +931,9 @@ def run(
run_manager.on_tool_error(error_to_raise)
raise error_to_raise
output = _format_output(content, artifact, tool_call_id, self.name, status)
run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
run_manager.on_tool_end(
output, color=color, name=self.name, tool_call_id=tool_call_id, **kwargs
)
return output

async def arun(
Expand Down Expand Up @@ -1004,6 +1007,7 @@ async def arun(
name=run_name,
run_id=run_id,
inputs=filtered_tool_input,
tool_call_id=tool_call_id,
**kwargs,
)
content = None
Expand Down Expand Up @@ -1060,7 +1064,9 @@ async def arun(
raise error_to_raise

output = _format_output(content, artifact, tool_call_id, self.name, status)
await run_manager.on_tool_end(output, color=color, name=self.name, **kwargs)
await run_manager.on_tool_end(
output, color=color, name=self.name, tool_call_id=tool_call_id, **kwargs
)
return output


Expand Down
145 changes: 145 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.callbacks.manager import (
CallbackManagerForRetrieverRun,
)
Expand Down Expand Up @@ -3127,3 +3128,147 @@ class MockRuntime:
assert captured is not None
assert captured == {"query": "test", "limit": 5}
assert "runtime" not in captured


class ToolCallIdCaptureHandler(BaseCallbackHandler):
    """Test double that remembers what the tool callbacks were called with.

    For both ``on_tool_start`` and ``on_tool_end`` it keeps the extra keyword
    arguments of the most recent call, plus the ``tool_call_id`` entry found
    among them (``None`` when that key is absent).
    """

    def __init__(self) -> None:
        super().__init__()
        # State recorded by on_tool_start.
        self.on_tool_start_kwargs: dict[str, Any] = {}
        self.on_tool_start_tool_call_id: str | None = None
        # State recorded by on_tool_end.
        self.on_tool_end_kwargs: dict[str, Any] = {}
        self.on_tool_end_tool_call_id: str | None = None

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: Any,
        parent_run_id: Any = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the ``tool_call_id`` and the remaining keyword arguments.

        Only keyword arguments not bound to a named parameter end up in
        ``kwargs``; the named parameters themselves are ignored.
        """
        self.on_tool_start_tool_call_id = kwargs.get("tool_call_id")
        self.on_tool_start_kwargs = kwargs

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: Any,
        parent_run_id: Any = None,
        **kwargs: Any,
    ) -> None:
        """Store the ``tool_call_id`` and the remaining keyword arguments.

        The tool ``output`` itself is not recorded.
        """
        self.on_tool_end_tool_call_id = kwargs.get("tool_call_id")
        self.on_tool_end_kwargs = kwargs


class AsyncToolCallIdCaptureHandler(AsyncCallbackHandler):
    """Async test double that remembers what the tool callbacks received.

    For both ``on_tool_start`` and ``on_tool_end`` it keeps the extra keyword
    arguments of the most recent call, plus the ``tool_call_id`` entry found
    among them (``None`` when that key is absent).
    """

    def __init__(self) -> None:
        super().__init__()
        # State recorded by on_tool_start.
        self.on_tool_start_kwargs: dict[str, Any] = {}
        self.on_tool_start_tool_call_id: str | None = None
        # State recorded by on_tool_end.
        self.on_tool_end_kwargs: dict[str, Any] = {}
        self.on_tool_end_tool_call_id: str | None = None

    async def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: Any,
        parent_run_id: Any = None,
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the ``tool_call_id`` and the remaining keyword arguments.

        Only keyword arguments not bound to a named parameter end up in
        ``kwargs``; the named parameters themselves are ignored.
        """
        self.on_tool_start_tool_call_id = kwargs.get("tool_call_id")
        self.on_tool_start_kwargs = kwargs

    async def on_tool_end(
        self,
        output: Any,
        *,
        run_id: Any,
        parent_run_id: Any = None,
        **kwargs: Any,
    ) -> None:
        """Store the ``tool_call_id`` and the remaining keyword arguments.

        The tool ``output`` itself is not recorded.
        """
        self.on_tool_end_tool_call_id = kwargs.get("tool_call_id")
        self.on_tool_end_kwargs = kwargs


def test_tool_call_id_passed_to_callbacks() -> None:
    """Check that sync tool callbacks receive the id of the invoking ToolCall.

    Regression test for https://github.com/langchain-ai/langchain/issues/34168
    """

    @tool
    def get_weather(location: str) -> str:
        """Get weather for a location."""
        return f"Sunny in {location}"

    expected_id = "call_123"
    handler = ToolCallIdCaptureHandler()

    # Run the tool through a ToolCall carrying an explicit id, with the
    # capturing handler attached as the only callback.
    request = ToolCall(
        name="get_weather",
        args={"location": "Paris"},
        id=expected_id,
        type="tool_call",
    )
    result = get_weather.invoke(request, config={"callbacks": [handler]})

    # Sanity check: the tool itself produced the expected content.
    assert result.content == "Sunny in Paris"

    # The id must reach on_tool_start ...
    assert handler.on_tool_start_tool_call_id == expected_id, (
        f"Expected tool_call_id='call_123' in on_tool_start, "
        f"got kwargs={handler.on_tool_start_kwargs}"
    )

    # ... and on_tool_end.
    assert handler.on_tool_end_tool_call_id == expected_id, (
        f"Expected tool_call_id='call_123' in on_tool_end, "
        f"got kwargs={handler.on_tool_end_kwargs}"
    )


async def test_tool_call_id_passed_to_async_callbacks() -> None:
    """Check that async tool callbacks receive the id of the invoking ToolCall.

    Regression test for https://github.com/langchain-ai/langchain/issues/34168
    """

    @tool
    async def get_weather(location: str) -> str:
        """Get weather for a location."""
        return f"Sunny in {location}"

    expected_id = "call_456"
    handler = AsyncToolCallIdCaptureHandler()

    # Run the tool through a ToolCall carrying an explicit id, with the
    # capturing handler attached as the only callback.
    request = ToolCall(
        name="get_weather",
        args={"location": "Paris"},
        id=expected_id,
        type="tool_call",
    )
    result = await get_weather.ainvoke(request, config={"callbacks": [handler]})

    # Sanity check: the tool itself produced the expected content.
    assert result.content == "Sunny in Paris"

    # The id must reach on_tool_start ...
    assert handler.on_tool_start_tool_call_id == expected_id, (
        f"Expected tool_call_id='call_456' in on_tool_start, "
        f"got kwargs={handler.on_tool_start_kwargs}"
    )

    # ... and on_tool_end.
    assert handler.on_tool_end_tool_call_id == expected_id, (
        f"Expected tool_call_id='call_456' in on_tool_end, "
        f"got kwargs={handler.on_tool_end_kwargs}"
    )