Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 87 additions & 82 deletions lib/crewai/src/crewai/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
from collections.abc import Callable
from inspect import signature
from typing import Any, cast, get_args, get_origin
from typing import Any, get_args, get_origin

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -55,7 +55,7 @@ class _ArgsSchemaPlaceholder(PydanticBaseModel):
default=False, description="Flag to check if the description has been updated."
)

cache_function: Callable = Field(
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
Expand All @@ -80,20 +80,21 @@ def _default_args_schema(
if v != cls._ArgsSchemaPlaceholder:
return v

return cast(
type[PydanticBaseModel],
type(
f"{cls.__name__}Schema",
(PydanticBaseModel,),
{
"__annotations__": {
k: v
for k, v in cls._run.__annotations__.items()
if k != "return"
},
},
),
)
run_sig = signature(cls._run)
fields: dict[str, Any] = {}

for param_name, param in run_sig.parameters.items():
if param_name in ("self", "return"):
continue

annotation = param.annotation if param.annotation != param.empty else Any

if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)

return create_model(f"{cls.__name__}Schema", **fields)

@field_validator("max_usage_count", mode="before")
@classmethod
Expand Down Expand Up @@ -164,24 +165,21 @@ def from_langchain(cls, tool: Any) -> BaseTool:
args_schema = getattr(tool, "args_schema", None)

if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
fields: dict[str, Any] = {}
for name, param in func_signature.parameters.items():
if name == "self":
continue
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[name] = (param_annotation, ...)
else:
fields[name] = (param_annotation, param.default)
if fields:
args_schema = create_model(f"{tool.name}Input", **fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
Expand All @@ -195,20 +193,24 @@ def from_langchain(cls, tool: Any) -> BaseTool:

def _set_args_schema(self) -> None:
if self.args_schema is None:
class_name = f"{self.__class__.__name__}Schema"
self.args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v
for k, v in self._run.__annotations__.items()
if k != "return"
},
},
),
run_sig = signature(self._run)
fields: dict[str, Any] = {}

for param_name, param in run_sig.parameters.items():
if param_name in ("self", "return"):
continue

annotation = (
param.annotation if param.annotation != param.empty else Any
)

if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)

self.args_schema = create_model(
f"{self.__class__.__name__}Schema", **fields
)

def _generate_description(self) -> None:
Expand Down Expand Up @@ -241,13 +243,13 @@ def _get_arg_annotations(annotation: type[Any] | None) -> str:
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"

return origin.__name__
return str(origin.__name__)


class Tool(BaseTool):
"""The function that will be executed when the tool is called."""

func: Callable
func: Callable[..., Any]

def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
Expand Down Expand Up @@ -275,24 +277,21 @@ def from_langchain(cls, tool: Any) -> Tool:
args_schema = getattr(tool, "args_schema", None)

if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
fields: dict[str, Any] = {}
for name, param in func_signature.parameters.items():
if name == "self":
continue
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[name] = (param_annotation, ...)
else:
fields[name] = (param_annotation, param.default)
if fields:
args_schema = create_model(f"{tool.name}Input", **fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
Expand All @@ -312,37 +311,43 @@ def to_langchain(


def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.
*args: Callable[..., Any] | str,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Callable[..., Any] | BaseTool:
"""Decorator to create a tool from a function.

Args:
*args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
"""

def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], BaseTool]:
def _make_tool(f: Callable[..., Any]) -> BaseTool:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
raise ValueError("Function must have type annotations")

func_sig = signature(f)
fields: dict[str, Any] = {}

for param_name, param in func_sig.parameters.items():
if param_name == "return":
continue

annotation = (
param.annotation if param.annotation != param.empty else Any
)

if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)

class_name = "".join(tool_name.split()).title()
args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
},
},
),
)
args_schema = create_model(class_name, **fields)

return Tool(
name=tool_name,
Expand Down
Loading
Loading