@codeflash-ai codeflash-ai bot commented Nov 13, 2025

📄 1,207% (12.07x) speedup for ToolManager._validate_backend_tool_arguments in marimo/_server/ai/tools/tool_manager.py

⏱️ Runtime : 8.63 milliseconds → 661 microseconds (best of 95 runs)
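
For readers checking the headline figure: the percentage appears to be the relative speedup (original − optimized) / optimized, not the plain ratio original / optimized. The sketch below is an assumption about that formula, and `speedup_percent` is a hypothetical helper, not part of Codeflash or marimo. Applied to the rounded runtimes quoted above it gives about 1,206%, so the small gap to the 1,207% headline presumably comes from unrounded timings.

```python
def speedup_percent(original_us: float, optimized_us: float) -> float:
    """Relative speedup in percent: (original - optimized) / optimized * 100."""
    return (original_us - optimized_us) / optimized_us * 100.0


# Rounded runtimes from the report, both expressed in microseconds.
overall = speedup_percent(8630.0, 661.0)
print(f"{overall:.0f}% ({overall / 100:.2f}x)")  # prints "1206% (12.06x)"
```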

📝 Explanation and details

The optimization achieves a 1206% speedup by eliminating an expensive operation in the _get_tool method's fallback path when no source is specified.

Key optimization: In the original code, when source=None, the method called self._get_all_tools() which gathered all backend tools AND all MCP tools from an external client. The line profiler shows this single call consumed 99.8% of execution time (21.5ms out of 21.5ms total).

The optimized version replaces this expensive all-tools scan with a two-stage lookup:

  1. First, directly check the backend tools dictionary (self._tools.get(name))
  2. Only if not found, then query MCP tools (self._list_mcp_tools())

Why this works: Most tool lookups likely target backend tools stored in the local dictionary. The optimization leverages this by checking the fast O(1) dictionary lookup first, avoiding the expensive MCP client call in the common case.
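
The difference between the two lookup strategies can be sketched as follows. This is a minimal illustration under stated assumptions, not the marimo implementation: `Tool`, `SketchToolManager`, `get_tool_slow`, `get_tool_fast`, and the `mcp_calls` counter are hypothetical names, and `_list_mcp_tools` here only simulates the external MCP client call. Only `_tools` and `_list_mcp_tools` are names taken from the description above.

```python
from typing import Dict, List, Optional


class Tool:
    """Hypothetical stand-in for a tool definition."""

    def __init__(self, name: str, source: str = "backend") -> None:
        self.name = name
        self.source = source


class SketchToolManager:
    """Hypothetical manager used only to contrast the two lookups."""

    def __init__(self) -> None:
        self._tools: Dict[str, Tool] = {}
        self.mcp_calls = 0  # how often the simulated MCP client was queried

    def _list_mcp_tools(self) -> List[Tool]:
        # Stand-in for the expensive external MCP client call.
        self.mcp_calls += 1
        return [Tool("remote_search", source="mcp")]

    def get_tool_slow(self, name: str) -> Optional[Tool]:
        # Original shape: gather every backend and MCP tool, then scan.
        all_tools = list(self._tools.values()) + self._list_mcp_tools()
        for tool in all_tools:
            if tool.name == name:
                return tool
        return None

    def get_tool_fast(self, name: str) -> Optional[Tool]:
        # Stage 1: O(1) dictionary lookup among backend tools.
        tool = self._tools.get(name)
        if tool is not None:
            return tool
        # Stage 2: pay for the MCP listing only after a backend miss.
        for tool in self._list_mcp_tools():
            if tool.name == name:
                return tool
        return None


manager = SketchToolManager()
manager._tools["adder"] = Tool("adder")
hit = manager.get_tool_fast("adder")  # backend hit: MCP client is never queried
calls_after_fast_hit = manager.mcp_calls
manager.get_tool_slow("adder")  # same result, but always queries MCP
calls_after_slow_hit = manager.mcp_calls
```

In this sketch a backend hit costs one dictionary lookup, while a miss still pays for the MCP listing, which is consistent with the smaller gains reported for failing lookups.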

Performance impact: The test results show dramatic improvements across all scenarios:

  • Simple successful lookups: 25,000-31,000% faster (roughly 1-2 μs vs. about 400 μs)
  • Failing validations (unknown tool, missing required parameter, validator rejection or exception): 168-768% faster
  • Large-scale passing tests with 100-500 required parameters: 2,197-10,641% faster

This optimization is particularly valuable if the function is called frequently during tool validation workflows, as it transforms an expensive operation into a fast dictionary lookup for the majority of cases.

Correctness verification report:

| Test | Status |
| --- | --- |
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 50 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import pytest
from marimo._server.ai.tools.tool_manager import ToolManager


# Minimal stubs for dependencies
class ToolDefinition:
    def __init__(self, name, parameters, source="backend"):
        self.name = name
        self.parameters = parameters
        self.source = source

# ----------- UNIT TESTS ------------

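class DummyApp:
    """Placeholder app; the second test module builds ToolManager the same way."""


@pytest.fixture
def manager():
    """Fixture for a ToolManager with no tools initially."""
    return ToolManager(DummyApp())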

# ----------- BASIC TEST CASES ------------

def test_valid_arguments_basic(manager):
    """Test: All required arguments provided, no validator."""
    manager._tools["adder"] = ToolDefinition(
        name="adder",
        parameters={"required": ["a", "b"]}
    )
    args = {"a": 1, "b": 2}
    valid, msg = manager._validate_backend_tool_arguments("adder", args)

def test_missing_required_argument_basic(manager):
    """Test: One required argument missing, no validator."""
    manager._tools["adder"] = ToolDefinition(
        name="adder",
        parameters={"required": ["a", "b"]}
    )
    args = {"a": 1}
    valid, msg = manager._validate_backend_tool_arguments("adder", args)

def test_tool_not_found_basic(manager):
    """Test: Tool does not exist."""
    args = {"a": 1, "b": 2}
    valid, msg = manager._validate_backend_tool_arguments("not_a_tool", args)

def test_no_required_params_basic(manager):
    """Test: Tool with no required parameters, should always pass."""
    manager._tools["noop"] = ToolDefinition(
        name="noop",
        parameters={"required": []}
    )
    args = {}
    valid, msg = manager._validate_backend_tool_arguments("noop", args)

def test_extra_arguments_basic(manager):
    """Test: Extra arguments provided, should still pass."""
    manager._tools["echo"] = ToolDefinition(
        name="echo",
        parameters={"required": ["msg"]}
    )
    args = {"msg": "hello", "extra": 123}
    valid, msg = manager._validate_backend_tool_arguments("echo", args)

# ----------- EDGE TEST CASES ------------

def test_validator_rejects_arguments(manager):
    """Test: Tool-specific validator rejects arguments."""
    def validator(args):
        # Reject if a is not positive
        if args.get("a", 0) <= 0:
            return (False, "Parameter 'a' must be positive")
        return (True, "")
    manager._tools["pos_checker"] = ToolDefinition(
        name="pos_checker",
        parameters={"required": ["a"]}
    )
    manager._validation_functions["pos_checker"] = validator
    args = {"a": -5}
    valid, msg = manager._validate_backend_tool_arguments("pos_checker", args)

def test_validator_accepts_arguments(manager):
    """Test: Tool-specific validator accepts arguments."""
    def validator(args):
        return (True, "")
    manager._tools["pass_checker"] = ToolDefinition(
        name="pass_checker",
        parameters={"required": ["a"]}
    )
    manager._validation_functions["pass_checker"] = validator
    args = {"a": 10}
    valid, msg = manager._validate_backend_tool_arguments("pass_checker", args)

def test_validator_raises_exception(manager):
    """Test: Tool-specific validator raises an exception."""
    def validator(args):
        raise ValueError("Bad validator logic!")
    manager._tools["broken_checker"] = ToolDefinition(
        name="broken_checker",
        parameters={"required": ["a"]}
    )
    manager._validation_functions["broken_checker"] = validator
    args = {"a": 1}
    valid, msg = manager._validate_backend_tool_arguments("broken_checker", args)

def test_validator_returns_none(manager):
    """Test: Tool-specific validator returns None, should fallback to basic validation."""
    def validator(args):
        return None
    manager._tools["none_checker"] = ToolDefinition(
        name="none_checker",
        parameters={"required": ["a"]}
    )
    manager._validation_functions["none_checker"] = validator
    args = {"a": 1}
    valid, msg = manager._validate_backend_tool_arguments("none_checker", args)

def test_required_param_is_none(manager):
    """Test: Required param is present but value is None (should pass, only key presence checked)."""
    manager._tools["none_param"] = ToolDefinition(
        name="none_param",
        parameters={"required": ["a"]}
    )
    args = {"a": None}
    valid, msg = manager._validate_backend_tool_arguments("none_param", args)

def test_arguments_is_empty_dict(manager):
    """Test: Arguments is empty dict, required parameters present."""
    manager._tools["empty_args"] = ToolDefinition(
        name="empty_args",
        parameters={"required": ["foo"]}
    )
    args = {}
    valid, msg = manager._validate_backend_tool_arguments("empty_args", args)

def test_tool_with_no_parameters_key(manager):
    """Test: Tool parameters dict missing 'required' key."""
    manager._tools["no_required"] = ToolDefinition(
        name="no_required",
        parameters={}
    )
    args = {}
    valid, msg = manager._validate_backend_tool_arguments("no_required", args)

def test_tool_with_non_list_required(manager):
    """Test: Tool parameters['required'] is not a list."""
    manager._tools["bad_required"] = ToolDefinition(
        name="bad_required",
        parameters={"required": "notalist"}
    )
    args = {}
    # Should treat as iterable, so each char is a param
    valid, msg = manager._validate_backend_tool_arguments("bad_required", args)

def test_tool_with_required_params_but_arguments_has_extra(manager):
    """Test: Arguments has extra keys, required keys present."""
    manager._tools["extra_args"] = ToolDefinition(
        name="extra_args",
        parameters={"required": ["x", "y"]}
    )
    args = {"x": 1, "y": 2, "z": 3}
    valid, msg = manager._validate_backend_tool_arguments("extra_args", args)

# ----------- LARGE SCALE TEST CASES ------------

def test_large_number_of_required_params(manager):
    """Test: Tool with many required params, all provided."""
    num_params = 500
    param_names = [f"p{i}" for i in range(num_params)]
    manager._tools["big_tool"] = ToolDefinition(
        name="big_tool",
        parameters={"required": param_names}
    )
    args = {name: i for i, name in enumerate(param_names)}
    valid, msg = manager._validate_backend_tool_arguments("big_tool", args)

def test_large_number_of_required_params_missing_one(manager):
    """Test: Tool with many required params, one missing."""
    num_params = 500
    param_names = [f"p{i}" for i in range(num_params)]
    manager._tools["big_tool"] = ToolDefinition(
        name="big_tool",
        parameters={"required": param_names}
    )
    args = {name: i for i, name in enumerate(param_names)}
    # Remove one param
    del args[param_names[-1]]
    valid, msg = manager._validate_backend_tool_arguments("big_tool", args)

def test_large_arguments_dict(manager):
    """Test: Arguments dict has many keys, only some required."""
    manager._tools["sparse_tool"] = ToolDefinition(
        name="sparse_tool",
        parameters={"required": ["foo", "bar"]}
    )
    # 1000 keys, only two required
    args = {f"key{i}": i for i in range(1000)}
    args["foo"] = "x"
    args["bar"] = "y"
    valid, msg = manager._validate_backend_tool_arguments("sparse_tool", args)

def test_large_validator(manager):
    """Test: Validator that checks a large number of arguments."""
    def validator(args):
        # Accept only if all values are even
        for v in args.values():
            if v % 2 != 0:
                return (False, "All values must be even")
        return (True, "")
    num_params = 300
    param_names = [f"p{i}" for i in range(num_params)]
    manager._tools["even_tool"] = ToolDefinition(
        name="even_tool",
        parameters={"required": param_names}
    )
    manager._validation_functions["even_tool"] = validator
    args = {name: i*2 for i, name in enumerate(param_names)}
    valid, msg = manager._validate_backend_tool_arguments("even_tool", args)
    # Now one odd value
    args[param_names[0]] = 1
    valid, msg = manager._validate_backend_tool_arguments("even_tool", args)

def test_large_scale_missing_all_required(manager):
    """Test: Tool with many required params, none provided."""
    num_params = 400
    param_names = [f"p{i}" for i in range(num_params)]
    manager._tools["missing_tool"] = ToolDefinition(
        name="missing_tool",
        parameters={"required": param_names}
    )
    args = {}
    valid, msg = manager._validate_backend_tool_arguments("missing_tool", args)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from typing import Any, Dict

# imports
import pytest
from marimo._server.ai.tools.tool_manager import ToolManager

# --- Minimal stubs for dependencies ---

class ToolDefinition:
    def __init__(self, name: str, parameters: Dict[str, Any], source: str):
        self.name = name
        self.parameters = parameters
        self.source = source

class DummyApp:
    pass

# --- Pytest test cases ---

@pytest.fixture
def tool_manager():
    # Setup a ToolManager with a few tools
    tm = ToolManager(DummyApp())
    # Tool with one required param
    tm._tools["simple_tool"] = ToolDefinition(
        name="simple_tool",
        parameters={"required": ["x"]},
        source="backend"
    )
    # Tool with two required params
    tm._tools["two_param_tool"] = ToolDefinition(
        name="two_param_tool",
        parameters={"required": ["a", "b"]},
        source="backend"
    )
    # Tool with no required params
    tm._tools["no_required_tool"] = ToolDefinition(
        name="no_required_tool",
        parameters={"required": []},
        source="backend"
    )
    # Tool with custom validation
    def validator(args):
        # Only accepts x > 0
        x = args.get("x", None)
        if x is None:
            return False, "Missing x"
        if x <= 0:
            return False, "x must be positive"
        return True, ""
    tm._tools["validated_tool"] = ToolDefinition(
        name="validated_tool",
        parameters={"required": ["x"]},
        source="backend"
    )
    tm._validation_functions["validated_tool"] = validator
    # Tool with validator that raises
    def error_validator(args):
        raise ValueError("Validation error!")
    tm._tools["error_tool"] = ToolDefinition(
        name="error_tool",
        parameters={"required": ["x"]},
        source="backend"
    )
    tm._validation_functions["error_tool"] = error_validator
    return tm

# --- Basic Test Cases ---

def test_basic_success(tool_manager):
    # The single required param is present
    result, msg = tool_manager._validate_backend_tool_arguments("simple_tool", {"x": 5}) # 435μs -> 1.68μs (25869% faster)

def test_basic_missing_param(tool_manager):
    # Missing required param
    result, msg = tool_manager._validate_backend_tool_arguments("simple_tool", {}) # 433μs -> 68.1μs (537% faster)

def test_basic_multiple_params_success(tool_manager):
    # All required params present
    result, msg = tool_manager._validate_backend_tool_arguments("two_param_tool", {"a": 1, "b": 2}) # 399μs -> 1.38μs (28876% faster)

def test_basic_multiple_params_missing(tool_manager):
    # One required param missing
    result, msg = tool_manager._validate_backend_tool_arguments("two_param_tool", {"a": 1}) # 424μs -> 58.2μs (629% faster)

def test_basic_no_required(tool_manager):
    # No required params, should always pass
    result, msg = tool_manager._validate_backend_tool_arguments("no_required_tool", {}) # 398μs -> 1.26μs (31627% faster)

def test_basic_tool_not_found(tool_manager):
    # Tool does not exist
    result, msg = tool_manager._validate_backend_tool_arguments("nonexistent_tool", {"x": 1}) # 387μs -> 144μs (168% faster)

# --- Edge Test Cases ---

def test_edge_empty_arguments(tool_manager):
    # Arguments is empty dict, but required param exists
    result, msg = tool_manager._validate_backend_tool_arguments("simple_tool", {}) # 417μs -> 50.8μs (722% faster)

def test_edge_extra_arguments(tool_manager):
    # Extra arguments should not affect validation
    result, msg = tool_manager._validate_backend_tool_arguments("simple_tool", {"x": 5, "y": 99}) # 384μs -> 1.38μs (27776% faster)

def test_edge_argument_is_none(tool_manager):
    # Argument present but value is None should still count as present
    result, msg = tool_manager._validate_backend_tool_arguments("simple_tool", {"x": None}) # 389μs -> 1.31μs (29655% faster)

def test_edge_custom_validator_success(tool_manager):
    # Passes custom validator
    result, msg = tool_manager._validate_backend_tool_arguments("validated_tool", {"x": 10}) # 389μs -> 1.48μs (26296% faster)

def test_edge_custom_validator_fail(tool_manager):
    # Fails custom validator
    result, msg = tool_manager._validate_backend_tool_arguments("validated_tool", {"x": -5}) # 423μs -> 60.7μs (598% faster)

def test_edge_custom_validator_missing_param(tool_manager):
    # Custom validator: missing param
    result, msg = tool_manager._validate_backend_tool_arguments("validated_tool", {}) # 424μs -> 49.2μs (762% faster)

def test_edge_validator_raises(tool_manager):
    # Validator raises exception
    result, msg = tool_manager._validate_backend_tool_arguments("error_tool", {"x": 5}) # 431μs -> 49.7μs (768% faster)

def test_edge_required_param_is_list(tool_manager):
    # Required param is present and its value is a list
    tm = ToolManager(DummyApp())
    tm._tools["list_tool"] = ToolDefinition(
        name="list_tool",
        parameters={"required": ["items"]},
        source="backend"
    )
    result, msg = tm._validate_backend_tool_arguments("list_tool", {"items": [1,2,3]}) # 389μs -> 1.39μs (27973% faster)

def test_edge_required_param_empty_string(tool_manager):
    # Required param present but value is empty string
    tm = ToolManager(DummyApp())
    tm._tools["str_tool"] = ToolDefinition(
        name="str_tool",
        parameters={"required": ["name"]},
        source="backend"
    )
    result, msg = tm._validate_backend_tool_arguments("str_tool", {"name": ""}) # 386μs -> 1.29μs (29772% faster)

def test_edge_required_param_false(tool_manager):
    # Required param present but value is False
    tm = ToolManager(DummyApp())
    tm._tools["bool_tool"] = ToolDefinition(
        name="bool_tool",
        parameters={"required": ["flag"]},
        source="backend"
    )
    result, msg = tm._validate_backend_tool_arguments("bool_tool", {"flag": False}) # 389μs -> 1.27μs (30547% faster)

# --- Large Scale Test Cases ---

def test_large_scale_many_required_params(tool_manager):
    # Tool with 500 required params, all present
    tm = ToolManager(DummyApp())
    required = [f"p{i}" for i in range(500)]
    tm._tools["big_tool"] = ToolDefinition(
        name="big_tool",
        parameters={"required": required},
        source="backend"
    )
    args = {f"p{i}": i for i in range(500)}
    result, msg = tm._validate_backend_tool_arguments("big_tool", args) # 434μs -> 18.9μs (2197% faster)

def test_large_scale_many_missing_params(tool_manager):
    # Tool with 500 required params, 1 missing
    tm = ToolManager(DummyApp())
    required = [f"p{i}" for i in range(500)]
    tm._tools["big_tool"] = ToolDefinition(
        name="big_tool",
        parameters={"required": required},
        source="backend"
    )
    args = {f"p{i}": i for i in range(499)}  # missing p499
    result, msg = tm._validate_backend_tool_arguments("big_tool", args) # 453μs -> 82.2μs (452% faster)

def test_large_scale_custom_validator(tool_manager):
    # Tool with custom validator that checks all params are positive
    def validator(args):
        for k, v in args.items():
            if v <= 0:
                return False, f"{k} must be positive"
        return True, ""
    tm = ToolManager(DummyApp())
    required = [f"p{i}" for i in range(100)]
    tm._tools["pos_tool"] = ToolDefinition(
        name="pos_tool",
        parameters={"required": required},
        source="backend"
    )
    tm._validation_functions["pos_tool"] = validator
    args = {f"p{i}": i+1 for i in range(100)}
    result, msg = tm._validate_backend_tool_arguments("pos_tool", args) # 405μs -> 3.77μs (10641% faster)

def test_large_scale_custom_validator_fail(tool_manager):
    # Tool with custom validator, one param negative
    def validator(args):
        for k, v in args.items():
            if v <= 0:
                return False, f"{k} must be positive"
        return True, ""
    tm = ToolManager(DummyApp())
    required = [f"p{i}" for i in range(100)]
    tm._tools["pos_tool"] = ToolDefinition(
        name="pos_tool",
        parameters={"required": required},
        source="backend"
    )
    tm._validation_functions["pos_tool"] = validator
    args = {f"p{i}": i+1 for i in range(99)}
    args["p99"] = -1
    result, msg = tm._validate_backend_tool_arguments("pos_tool", args) # 430μs -> 60.9μs (608% faster)

def test_large_scale_no_required(tool_manager):
    # Tool with zero required params, many args provided
    tm = ToolManager(DummyApp())
    tm._tools["free_tool"] = ToolDefinition(
        name="free_tool",
        parameters={"required": []},
        source="backend"
    )
    args = {f"x{i}": i for i in range(1000)}
    result, msg = tm._validate_backend_tool_arguments("free_tool", args) # 405μs -> 1.24μs (32603% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, git checkout codeflash/optimize-ToolManager._validate_backend_tool_arguments-mhwt455a and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 13, 2025 02:25
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 13, 2025