Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/mcp_interviewer/statistics/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

import tiktoken
from mcp.types import Tool
from openai.types.chat import ChatCompletionTool
from openai.types.shared import FunctionDefinition
from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params import FunctionDefinition

from .base import CompositeStatistic, ServerScoreCard, Statistic, StatisticValue

logger = logging.getLogger(__name__)


def num_tokens_for_tool(tool: ChatCompletionTool, model):
def num_tokens_for_tool(tool: ChatCompletionToolParam, model):
"""From https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb"""

# Initialize function settings to 0
Expand Down Expand Up @@ -54,9 +54,9 @@ def num_tokens_for_tool(tool: ChatCompletionTool, model):

func_token_count = 0
func_token_count += func_init # Add tokens for start of each function
function = tool.function
f_name = function.name
f_desc = function.description or ""
function = tool["function"]
f_name = function["name"]
f_desc = function.get("description", "") or ""
if f_desc.endswith("."):
f_desc = f_desc[:-1]
line = f_name + ":" + f_desc
Expand All @@ -65,8 +65,9 @@ def num_tokens_for_tool(tool: ChatCompletionTool, model):
) # Add tokens for set name and description

properties = {}
if function.parameters is not None and "properties" in function.parameters:
properties = cast(dict, function.parameters["properties"])
function_parameters = function.get("parameters")
if function_parameters is not None and "properties" in function_parameters:
properties = cast(dict, function_parameters["properties"])

if len(properties) > 0:
func_token_count += prop_init # Add tokens for start of each property
Expand Down Expand Up @@ -100,11 +101,11 @@ def compute(self, server: ServerScoreCard) -> Generator[StatisticValue, None, No

class ToolInputSchemaTokenCount(ToolStatistic):
def compute_tool(self, tool: Tool) -> Generator[StatisticValue, None, None]:
oai_tool = ChatCompletionTool(
oai_tool = ChatCompletionToolParam(
type="function",
function=FunctionDefinition(
name=tool.name,
description=tool.description,
description=tool.description or "",
parameters=tool.inputSchema,
),
)
Expand Down Expand Up @@ -135,7 +136,7 @@ class ToolInputSchemaMaxDepthCount(ToolStatistic):
def compute_tool(self, tool: Tool) -> Generator[StatisticValue, None, None]:
def get_max_depth(o, depth=0):
max_depth = depth
if isinstance(o, list | tuple):
if isinstance(o, (list, tuple)): # noqa: UP038
for item in o:
max_depth = max(get_max_depth(item, depth + 1), max_depth)
elif isinstance(o, dict):
Expand Down