diff --git a/src/dedalus_labs/lib/runner/_scheduler.py b/src/dedalus_labs/lib/runner/_scheduler.py new file mode 100644 index 0000000..fe94b3a --- /dev/null +++ b/src/dedalus_labs/lib/runner/_scheduler.py @@ -0,0 +1,293 @@ +# ============================================================================== +# © 2025 Dedalus Labs, Inc. and affiliates +# Licensed under MIT +# github.com/dedalus-labs/dedalus-sdk-python/LICENSE +# ============================================================================== + +"""Dependency-aware local tool scheduler. + +When the API server returns ``pending_local_calls`` with dependency +info, this module topo-sorts and executes them in parallel layers. +Independent tools fire concurrently; dependent tools wait for their +prerequisites. + +Falls back to sequential execution when dependencies form a cycle +(model hallucinated wrong deps). + +Functions: + execute_local_tools_async -- async path, uses asyncio.gather per layer + execute_local_tools_sync -- sync path, sequential within layers +""" + +from __future__ import annotations + +import json +import asyncio +from typing import Any, Dict, List, Tuple + +from graphlib import CycleError, TopologicalSorter + + +def _parse_pending_calls( + tool_calls: List[Dict[str, Any]], +) -> Tuple[Dict[str, Dict[str, Any]], TopologicalSorter]: + """Parse pending local calls and build a TopologicalSorter. + + Returns (calls_by_id, sorter). Each entry in calls_by_id has + the parsed function name and arguments ready for execution. + + Raises CycleError if dependencies are cyclic. + """ + calls_by_id: Dict[str, Dict[str, Any]] = {} + sorter: TopologicalSorter = TopologicalSorter() + known_ids = {tc["id"] for tc in tool_calls if "id" in tc} + + for tc in tool_calls: + call_id = tc.get("id", "") + fn_name = tc["function"]["name"] + fn_args_str = tc["function"]["arguments"] + + try: + fn_args = json.loads(fn_args_str) if isinstance(fn_args_str, str) else fn_args_str + except json.JSONDecodeError: + fn_args = {} + + # Filter deps to only known ids in this batch. + raw_deps = tc.get("dependencies") or [] + deps = [dep for dep in raw_deps if dep in known_ids and dep != call_id] + + calls_by_id[call_id] = {"name": fn_name, "args": fn_args, "id": call_id} + sorter.add(call_id, *deps) + + sorter.prepare() + return calls_by_id, sorter + + +async def execute_local_tools_async( + tool_calls: List[Dict[str, Any]], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]], + tools_called: List[str], + step: int, + *, + verbose: bool = False, +) -> None: + """Execute local tool calls respecting dependencies (async). + + Independent tools within the same topo layer fire concurrently + via asyncio.gather. Falls back to sequential on cycle. + + The caller is responsible for appending the assistant message + (with tool_calls and any reasoning_content) before calling this. + """ + if not tool_calls: + return + + try: + calls_by_id, sorter = _parse_pending_calls(tool_calls) + except CycleError: + # If wrong deps from model, fall back to sequential. + await _execute_sequential_async( + tool_calls, + tool_handler, + messages, + tool_results, + tools_called, + step, + verbose, + ) + return + + # Drive the sorter layer by layer. + while sorter.is_active(): + ready = list(sorter.get_ready()) + if not ready: + break + + if len(ready) == 1: + # Single tool: no gather overhead. 
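+            # Awaiting the lone coroutine directly also skips the extra Task
+            # that a one-element asyncio.gather would otherwise schedule.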
+ call_id = ready[0] + await _run_one_async( + calls_by_id[call_id], + tool_handler, + messages, + tool_results, + tools_called, + step, + verbose, + ) + sorter.done(call_id) + else: + # Multiple independent tools: fire concurrently. + results = await asyncio.gather( + *[ + _run_one_async( + calls_by_id[call_id], + tool_handler, + messages, + tool_results, + tools_called, + step, + verbose, + ) + for call_id in ready + ], + return_exceptions=True, + ) + + for call_id, result in zip(ready, results): + if isinstance(result, Exception): + # Already recorded in messages by _run_one_async. + pass + sorter.done(call_id) + + +def execute_local_tools_sync( + tool_calls: List[Dict[str, Any]], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]], + tools_called: List[str], + step: int, +) -> None: + """Execute local tool calls respecting dependencies (sync). + + Executes in topo order, one at a time. No parallelism in sync mode, + but ordering is correct. + + The caller is responsible for appending the assistant message + (with tool_calls and any reasoning_content) before calling this. + """ + if not tool_calls: + return + + try: + calls_by_id, sorter = _parse_pending_calls(tool_calls) + except CycleError: + _execute_sequential_sync( + tool_calls, + tool_handler, + messages, + tool_results, + tools_called, + step, + ) + return + + while sorter.is_active(): + ready = list(sorter.get_ready()) + if not ready: + break + for call_id in ready: + _run_one_sync( + calls_by_id[call_id], + tool_handler, + messages, + tool_results, + tools_called, + step, + ) + sorter.done(call_id) + + +# --- Single tool execution --- + + +async def _run_one_async( + call: Dict[str, Any], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]], + tools_called: List[str], + step: int, + verbose: bool, +) -> None: + """Execute a single tool call and record results.""" + fn_name = call["name"] + fn_args = call["args"] + call_id = call["id"] + + try: + result = await tool_handler.exec(fn_name, fn_args) + tool_results.append({"name": fn_name, "result": result, "step": step}) + tools_called.append(fn_name) + messages.append({"role": "tool", "tool_call_id": call_id, "content": str(result)}) + if verbose: + print(f" Tool {fn_name}: {str(result)[:50]}...") # noqa: T201 + except Exception as e: + tool_results.append({"error": str(e), "name": fn_name, "step": step}) + messages.append({"role": "tool", "tool_call_id": call_id, "content": f"Error: {e}"}) + if verbose: + print(f" Tool {fn_name} failed: {e}") # noqa: T201 + + +def _run_one_sync( + call: Dict[str, Any], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]], + tools_called: List[str], + step: int, +) -> None: + """Execute a single tool call synchronously and record results.""" + fn_name = call["name"] + fn_args = call["args"] + call_id = call["id"] + + try: + result = tool_handler.exec_sync(fn_name, fn_args) + tool_results.append({"name": fn_name, "result": result, "step": step}) + tools_called.append(fn_name) + messages.append({"role": "tool", "tool_call_id": call_id, "content": str(result)}) + except Exception as e: + tool_results.append({"error": str(e), "name": fn_name, "step": step}) + messages.append({"role": "tool", "tool_call_id": call_id, "content": f"Error: {e}"}) + + +# --- Sequential fallback --- + + +async def _execute_sequential_async( + tool_calls: List[Dict[str, Any]], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: 
List[Dict[str, Any]], + tools_called: List[str], + step: int, + verbose: bool, +) -> None: + """Fallback: execute all tools sequentially (no dependency ordering).""" + for tc in tool_calls: + fn_name = tc["function"]["name"] + fn_args_str = tc["function"]["arguments"] + call_id = tc.get("id", "") + try: + fn_args = json.loads(fn_args_str) if isinstance(fn_args_str, str) else fn_args_str + except json.JSONDecodeError: + fn_args = {} + + call = {"name": fn_name, "args": fn_args, "id": call_id} + await _run_one_async(call, tool_handler, messages, tool_results, tools_called, step, verbose) + + +def _execute_sequential_sync( + tool_calls: List[Dict[str, Any]], + tool_handler: Any, + messages: List[Dict[str, Any]], + tool_results: List[Dict[str, Any]], + tools_called: List[str], + step: int, +) -> None: + """Fallback: execute all tools sequentially (no dependency ordering).""" + for tc in tool_calls: + fn_name = tc["function"]["name"] + fn_args_str = tc["function"]["arguments"] + call_id = tc.get("id", "") + try: + fn_args = json.loads(fn_args_str) if isinstance(fn_args_str, str) else fn_args_str + except json.JSONDecodeError: + fn_args = {} + + call = {"name": fn_name, "args": fn_args, "id": call_id} + _run_one_sync(call, tool_handler, messages, tool_results, tools_called, step) diff --git a/src/dedalus_labs/lib/runner/core.py b/src/dedalus_labs/lib/runner/core.py index 8ed01a4..94f0acf 100644 --- a/src/dedalus_labs/lib/runner/core.py +++ b/src/dedalus_labs/lib/runner/core.py @@ -6,31 +6,29 @@ from __future__ import annotations -import json import asyncio import inspect from typing import ( TYPE_CHECKING, Any, Dict, + Union, Literal, Callable, Iterator, Protocol, - AsyncIterator, Sequence, - Union, + AsyncIterator, ) -from dataclasses import field, asdict, dataclass +from dataclasses import field, dataclass if TYPE_CHECKING: from ...types.shared.dedalus_model import DedalusModel -from ..._client import Dedalus, AsyncDedalus - +from ..mcp import MCPServerProtocol, serialize_mcp_servers from .types import Message, ToolCall, JsonValue, ToolResult, PolicyInput, PolicyContext +from ..._client import Dedalus, AsyncDedalus from ...types.shared import MCPToolResult -from ..mcp import serialize_mcp_servers, MCPServerProtocol # Type alias for mcp_servers parameter - accepts strings, server objects, or mixed lists MCPServersInput = Union[ @@ -120,22 +118,17 @@ def exec_sync(self, name: str, args: Dict[str, JsonValue]) -> JsonValue: @dataclass class _ModelConfig: - """Model configuration parameters.""" + """Model routing info + passthrough API kwargs. + + ``api_kwargs`` holds every parameter destined for the chat + completions API (temperature, reasoning_effort, thinking, etc.). + The runner doesn't interpret most of them — it just forwards + them to ``client.chat.completions.create(**api_kwargs)``. 
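+
+    For example, ``api_kwargs={"temperature": 0.2, "max_tokens": 256}``
+    (illustrative values) is passed through to the API call unchanged.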
+ """ id: str - model_list: list[str] | None = None # Store the full model list when provided - temperature: float | None = None - max_tokens: int | None = None - top_p: float | None = None - frequency_penalty: float | None = None - presence_penalty: float | None = None - logit_bias: Dict[str, int] | None = None - response_format: Dict[str, JsonValue] | type | None = None # Dict or Pydantic model - agent_attributes: Dict[str, float] | None = None - model_attributes: Dict[str, Dict[str, float]] | None = None - tool_choice: str | Dict[str, JsonValue] | None = None - guardrails: list[Dict[str, JsonValue]] | None = None - handoff_config: Dict[str, JsonValue] | None = None + model_list: list[str] | None = None + api_kwargs: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -184,6 +177,97 @@ def to_input_list(self) -> list[Message]: return list(self.messages) +def _collect_api_kwargs(**params: Any) -> Dict[str, Any]: + """Build API kwargs dict from explicit params, filtering out Nones.""" + return {k: v for k, v in params.items() if v is not None} + + +# Params that DedalusModel may carry and should be extracted. +_MODEL_EXTRACT_PARAMS = ( + "temperature", + "max_tokens", + "top_p", + "frequency_penalty", + "presence_penalty", + "logit_bias", + "tool_choice", + "reasoning_effort", + "thinking", + "n", + "stop", + "stream_options", + "logprobs", + "top_logprobs", + "seed", + "service_tier", + "parallel_tool_calls", + "user", + "max_completion_tokens", +) + + +def _extract_from_dedalus_model( + model_obj: Any, + api_kwargs: Dict[str, Any], +) -> bool: + """Extract params from a DedalusModel into api_kwargs. + + Explicit values already in api_kwargs take precedence. + Returns True if stream should be overridden from the model. + """ + for param in _MODEL_EXTRACT_PARAMS: + if param not in api_kwargs: + val = getattr(model_obj, param, None) + if val is not None: + api_kwargs[param] = val + + # Dedalus-specific: attributes → agent_attributes + if "agent_attributes" not in api_kwargs: + attrs = getattr(model_obj, "attributes", None) + if attrs: + api_kwargs["agent_attributes"] = attrs + + return getattr(model_obj, "stream", False) + + +def _parse_model( + model: Any, + api_kwargs: Dict[str, Any], + stream: bool, +) -> tuple: + """Parse model param into (model_name, model_list, stream). + + Handles strings, DedalusModel objects, and lists of either. + Extracts model-embedded params into api_kwargs. 
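+
+    For example, a bare string such as ``"gpt-4o"`` (illustrative) comes
+    back as ``("gpt-4o", ["gpt-4o"], stream)`` with api_kwargs untouched.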
+ """ + if isinstance(model, list): + if not model: + raise ValueError("model list cannot be empty") + model_name = None + model_list = [] + for m in model: + if hasattr(m, "name"): + model_list.append(m.name) + if model_name is None: + model_name = m.name + model_stream = _extract_from_dedalus_model(m, api_kwargs) + if not stream: + stream = model_stream + else: + model_list.append(m) + if model_name is None: + model_name = m + return model_name, model_list, stream + + if hasattr(model, "name"): + model_stream = _extract_from_dedalus_model(model, api_kwargs) + if not stream: + stream = model_stream + return model.name, [model.name], stream + + return model, [model] if model else [], stream + + class DedalusRunner: """Enhanced Dedalus client with tool execution capabilities.""" @@ -198,32 +282,71 @@ def run( messages: list[Message] | None = None, instructions: str | None = None, model: str | list[str] | DedalusModel | list[DedalusModel] | None = None, + # --- Runner config --- max_steps: int = 10, mcp_servers: MCPServersInput = None, - credentials: Sequence[Any] | None = None, # TODO: Loosely typed as `Any` for now - temperature: float | None = None, - max_tokens: int | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - logit_bias: Dict[str, int] | None = None, - response_format: Dict[str, JsonValue] | type | None = None, + credentials: Sequence[Any] | None = None, stream: bool = False, transport: Literal["http", "realtime"] = "http", verbose: bool | None = None, debug: bool | None = None, on_tool_event: Callable[[Dict[str, JsonValue]], None] | None = None, return_intent: bool = False, - agent_attributes: Dict[str, float] | None = None, - model_attributes: Dict[str, Dict[str, float]] | None = None, - tool_choice: str | Dict[str, JsonValue] | None = None, - guardrails: list[Dict[str, JsonValue]] | None = None, - handoff_config: Dict[str, JsonValue] | None = None, policy: PolicyInput = None, available_models: list[str] | None = None, strict_models: bool = True, + agent_attributes: Dict[str, float] | None = None, + audio: Dict[str, Any] | None = None, + cached_content: str | None = None, + deferred: bool | None = None, + frequency_penalty: float | None = None, + function_call: str | None = None, + generation_config: Dict[str, Any] | None = None, + guardrails: list[Dict[str, JsonValue]] | None = None, + handoff_config: Dict[str, JsonValue] | None = None, + logit_bias: Dict[str, int] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + metadata: Dict[str, Any] | None = None, + modalities: list[str] | None = None, + model_attributes: Dict[str, Dict[str, float]] | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + prediction: Dict[str, Any] | None = None, + presence_penalty: float | None = None, + prompt_cache_key: str | None = None, + prompt_cache_retention: str | None = None, + prompt_mode: str | None = None, + reasoning_effort: str | None = None, + response_format: Dict[str, JsonValue] | type | None = None, + safe_prompt: bool | None = None, + safety_identifier: str | None = None, + safety_settings: list[Dict[str, Any]] | None = None, + search_parameters: Dict[str, Any] | None = None, + seed: int | None = None, + service_tier: str | None = None, + stop: str | list[str] | None = None, + store: bool | None = None, + stream_options: Dict[str, Any] | None = None, + system_instruction: str | Dict[str, Any] | None = None, 
+ temperature: float | None = None, + thinking: Dict[str, Any] | None = None, + tool_choice: str | Dict[str, JsonValue] | None = None, + tool_config: Dict[str, Any] | None = None, + top_k: int | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + verbosity: str | None = None, + web_search_options: Dict[str, Any] | None = None, ): - """Execute tools with unified async/sync + streaming/non-streaming logic.""" + """Execute a tool-enabled conversation. + + All parameters from the chat completions API are accepted and + forwarded to the server verbatim. See ``CompletionCreateParamsBase`` + for full documentation of each parameter. + """ if not model: raise ValueError("model must be provided") @@ -233,7 +356,6 @@ def run( msg = "tools must be a list of callable functions or None" raise ValueError(msg) - # Check for nested lists (common mistake: tools=[[]] instead of tools=[]) for i, tool in enumerate(tools): if not callable(tool): if isinstance(tool, list): @@ -244,160 +366,64 @@ def run( ) raise TypeError(msg) - # Parse model to extract name and config - model_name = None - model_list = [] + # Collect all API kwargs, filtering out Nones. + api_kwargs = _collect_api_kwargs( + agent_attributes=agent_attributes, + audio=audio, + cached_content=cached_content, + deferred=deferred, + frequency_penalty=frequency_penalty, + function_call=function_call, + generation_config=generation_config, + guardrails=guardrails, + handoff_config=handoff_config, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + metadata=metadata, + modalities=modalities, + model_attributes=model_attributes, + n=n, + parallel_tool_calls=parallel_tool_calls, + prediction=prediction, + presence_penalty=presence_penalty, + prompt_cache_key=prompt_cache_key, + prompt_cache_retention=prompt_cache_retention, + prompt_mode=prompt_mode, + reasoning_effort=reasoning_effort, + response_format=response_format, + safe_prompt=safe_prompt, + safety_identifier=safety_identifier, + safety_settings=safety_settings, + search_parameters=search_parameters, + seed=seed, + service_tier=service_tier, + stop=stop, + store=store, + stream_options=stream_options, + system_instruction=system_instruction, + temperature=temperature, + thinking=thinking, + tool_choice=tool_choice, + tool_config=tool_config, + top_k=top_k, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + verbosity=verbosity, + web_search_options=web_search_options, + ) - if isinstance(model, list): - if not model: - raise ValueError("model list cannot be empty") - # Handle list of DedalusModel or strings - for m in model: - if hasattr(m, "name"): # DedalusModel - model_list.append(m.name) - # Use config from first DedalusModel if params not explicitly set - if model_name is None: - model_name = m.name - temperature = temperature if temperature is not None else getattr(m, "temperature", None) - max_tokens = max_tokens if max_tokens is not None else getattr(m, "max_tokens", None) - top_p = top_p if top_p is not None else getattr(m, "top_p", None) - frequency_penalty = ( - frequency_penalty - if frequency_penalty is not None - else getattr(m, "frequency_penalty", None) - ) - presence_penalty = ( - presence_penalty if presence_penalty is not None else getattr(m, "presence_penalty", None) - ) - logit_bias = logit_bias if logit_bias is not None else getattr(m, "logit_bias", None) - - # Extract additional parameters from first DedalusModel - stream = stream if stream is 
not False else getattr(m, "stream", False) - tool_choice = tool_choice if tool_choice is not None else getattr(m, "tool_choice", None) - - # Extract Dedalus-specific extensions - if hasattr(m, "attributes") and m.attributes: - agent_attributes = agent_attributes if agent_attributes is not None else m.attributes - - # Check for unsupported parameters (only warn once for first model) - unsupported_params = [] - if hasattr(m, "n") and m.n is not None: - unsupported_params.append("n") - if hasattr(m, "stop") and m.stop is not None: - unsupported_params.append("stop") - if hasattr(m, "stream_options") and m.stream_options is not None: - unsupported_params.append("stream_options") - if hasattr(m, "logprobs") and m.logprobs is not None: - unsupported_params.append("logprobs") - if hasattr(m, "top_logprobs") and m.top_logprobs is not None: - unsupported_params.append("top_logprobs") - if hasattr(m, "seed") and m.seed is not None: - unsupported_params.append("seed") - if hasattr(m, "service_tier") and m.service_tier is not None: - unsupported_params.append("service_tier") - if hasattr(m, "tools") and m.tools is not None: - unsupported_params.append("tools") - if hasattr(m, "parallel_tool_calls") and m.parallel_tool_calls is not None: - unsupported_params.append("parallel_tool_calls") - if hasattr(m, "user") and m.user is not None: - unsupported_params.append("user") - if hasattr(m, "max_completion_tokens") and m.max_completion_tokens is not None: - unsupported_params.append("max_completion_tokens") - - if unsupported_params: - import warnings - - warnings.warn( - f"The following DedalusModel parameters are not yet supported and will be ignored: {', '.join(unsupported_params)}. " - f"Support for these parameters is coming soon.", - UserWarning, - stacklevel=2, - ) - else: # String - model_list.append(m) - if model_name is None: - model_name = m - elif hasattr(model, "name"): # Single DedalusModel - model_name = model.name - model_list = [model.name] - # Extract config from DedalusModel if params not explicitly set - temperature = temperature if temperature is not None else getattr(model, "temperature", None) - max_tokens = max_tokens if max_tokens is not None else getattr(model, "max_tokens", None) - top_p = top_p if top_p is not None else getattr(model, "top_p", None) - frequency_penalty = ( - frequency_penalty if frequency_penalty is not None else getattr(model, "frequency_penalty", None) - ) - presence_penalty = ( - presence_penalty if presence_penalty is not None else getattr(model, "presence_penalty", None) - ) - logit_bias = logit_bias if logit_bias is not None else getattr(model, "logit_bias", None) - - # Extract additional supported parameters - stream = stream if stream is not False else getattr(model, "stream", False) - tool_choice = tool_choice if tool_choice is not None else getattr(model, "tool_choice", None) - - # Extract Dedalus-specific extensions - if hasattr(model, "attributes") and model.attributes: - agent_attributes = agent_attributes if agent_attributes is not None else model.attributes - if hasattr(model, "metadata") and model.metadata: - # metadata is stored but not yet fully utilized - pass - - # Log warnings for unsupported parameters - unsupported_params = [] - if hasattr(model, "n") and model.n is not None: - unsupported_params.append("n") - if hasattr(model, "stop") and model.stop is not None: - unsupported_params.append("stop") - if hasattr(model, "stream_options") and model.stream_options is not None: - unsupported_params.append("stream_options") - if hasattr(model, 
"logprobs") and model.logprobs is not None: - unsupported_params.append("logprobs") - if hasattr(model, "top_logprobs") and model.top_logprobs is not None: - unsupported_params.append("top_logprobs") - if hasattr(model, "seed") and model.seed is not None: - unsupported_params.append("seed") - if hasattr(model, "service_tier") and model.service_tier is not None: - unsupported_params.append("service_tier") - if hasattr(model, "tools") and model.tools is not None: - unsupported_params.append("tools") - if hasattr(model, "parallel_tool_calls") and model.parallel_tool_calls is not None: - unsupported_params.append("parallel_tool_calls") - if hasattr(model, "user") and model.user is not None: - unsupported_params.append("user") - if hasattr(model, "max_completion_tokens") and model.max_completion_tokens is not None: - unsupported_params.append("max_completion_tokens") - - if unsupported_params: - import warnings - - warnings.warn( - f"The following DedalusModel parameters are not yet supported and will be ignored: {', '.join(unsupported_params)}. " - f"Support for these parameters is coming soon.", - UserWarning, - stacklevel=2, - ) - else: # Single string - model_name = model - model_list = [model] if model else [] + # Parse model to extract name, list, and any model-embedded params. + model_name, model_list, stream = _parse_model(model, api_kwargs, stream) available_models = model_list if available_models is None else available_models model_config = _ModelConfig( id=str(model_name), - model_list=model_list, # Pass the full model list - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - logit_bias=logit_bias, - response_format=response_format, - agent_attributes=agent_attributes, - model_attributes=model_attributes, - tool_choice=tool_choice, - guardrails=guardrails, - handoff_config=handoff_config, + model_list=model_list, + api_kwargs=api_kwargs, ) # Serialize mcp_servers to wire format @@ -728,80 +754,33 @@ async def _execute_streaming_async( print(f" Local tools used: {local_names}") print(f" Server tools used: {mcp_names}") - # When MCP tools are involved and content was streamed, we're done - if mcp_names and has_streamed_content: + # All tools are server side and results have already been streamed. + if all_mcp and has_streamed_content: if exec_config.verbose: - print(f" MCP tools called and content streamed - response complete, breaking loop") + print(f" All tools are MCP and content streamed, breaking loop") break - if all_mcp: - # All tools are MCP - the response should be streamed - if exec_config.verbose: - print(f" All tools are MCP, expecting streamed response") - # Don't break here - let the next iteration handle it - else: - # We have at least one local tool - # Filter to only include local tool calls in the assistant message - local_only_tool_calls = [ + # At least one local tool exists. Execute via the dependency aware scheduler. 
+ if not all_mcp: + local_only = [ tc for tc in tool_calls if tc["function"]["name"] in getattr(tool_handler, "_funcs", {}) ] - messages.append({"role": "assistant", "tool_calls": local_only_tool_calls}) - if exec_config.verbose: - print( - f" Added assistant message with {len(local_only_tool_calls)} local tool calls (filtered from {len(tool_calls)} total)" - ) - - # Execute only local tools - for tc in tool_calls: - fn_name = tc["function"]["name"] - fn_args_str = tc["function"]["arguments"] - - if fn_name in getattr(tool_handler, "_funcs", {}): - # Local tool - try: - fn_args = json.loads(fn_args_str) - except json.JSONDecodeError: - fn_args = {} - - try: - result = await tool_handler.exec(fn_name, fn_args) - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": str(result), - } - ) - if exec_config.verbose: - print(f" Executed local tool {fn_name}: {str(result)[:50]}...") - except Exception as e: - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": f"Error: {str(e)}", - } - ) - if exec_config.verbose: - print(f" Error executing local tool {fn_name}: {e}") - else: - # MCP tool - DON'T add any message - # The API server should handle this - if exec_config.verbose: - print(f" MCP tool {fn_name} - skipping (server will handle)") + messages.append({"role": "assistant", "tool_calls": local_only}) + + from ._scheduler import execute_local_tools_async + + await execute_local_tools_async( + local_only, + tool_handler, + messages, + [], + [], + steps, + verbose=exec_config.verbose, + ) if exec_config.verbose: print(f" Messages after tool execution: {len(messages)}") - - # Only continue if we have NO MCP tools - if not mcp_names: - print(f" No MCP tools, continuing loop to step {steps + 1}...") - else: - print(f" MCP tools present, expecting response in next iteration") - - # Continue loop only if we need another response - if exec_config.verbose: - print(f" Tool processing complete") else: if exec_config.verbose: print(f" No tool calls found, breaking out of loop") @@ -1058,80 +1037,32 @@ def _execute_streaming_sync( print(f" Local tools: {local_names}") print(f" Server tools: {mcp_names}") - # When MCP tools are involved and content was streamed, we're done - if mcp_names and has_streamed_content: + # All tools are server side and results have already been streamed. + if all_mcp and has_streamed_content: if exec_config.verbose: - print(f" MCP tools called and content streamed - response complete, breaking loop") + print(f" All tools are MCP and content streamed, breaking loop") break - if all_mcp: - # All tools are MCP - the response should be streamed - if exec_config.verbose: - print(f" All tools are MCP, expecting streamed response") - # Don't break here - let the next iteration handle it - else: - # We have at least one local tool - # Filter to only include local tool calls in the assistant message - local_only_tool_calls = [ + # At least one local tool exists. Execute via the dependency aware scheduler. 
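+            # Sync path: no parallelism, but the scheduler still honors
+            # dependency order before each local tool runs.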
+ if not all_mcp: + local_only = [ tc for tc in tool_calls if tc["function"]["name"] in getattr(tool_handler, "_funcs", {}) ] - messages.append({"role": "assistant", "tool_calls": local_only_tool_calls}) - if exec_config.verbose: - print( - f" Added assistant message with {len(local_only_tool_calls)} local tool calls (filtered from {len(tool_calls)} total)" - ) + messages.append({"role": "assistant", "tool_calls": local_only}) - # Execute only local tools - for tc in tool_calls: - fn_name = tc["function"]["name"] - fn_args_str = tc["function"]["arguments"] - - if fn_name in getattr(tool_handler, "_funcs", {}): - # Local tool - try: - fn_args = json.loads(fn_args_str) - except json.JSONDecodeError: - fn_args = {} - - try: - result = tool_handler.exec_sync(fn_name, fn_args) - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": str(result), - } - ) - if exec_config.verbose: - print(f" Executed local tool {fn_name}: {str(result)[:50]}...") - except Exception as e: - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": f"Error: {str(e)}", - } - ) - if exec_config.verbose: - print(f" Error executing local tool {fn_name}: {e}") - else: - # MCP tool - DON'T add any message - # The API server should handle this - if exec_config.verbose: - print(f" MCP tool {fn_name} - skipping (server will handle)") + from ._scheduler import execute_local_tools_sync - if exec_config.verbose: - print(f" Messages after tool execution: {len(messages)}") - - # Only continue if we have NO MCP tools - if not mcp_names: - print(f" No MCP tools, continuing loop to step {steps + 1}...") - else: - print(f" MCP tools present, expecting response in next iteration") + execute_local_tools_sync( + local_only, + tool_handler, + messages, + [], + [], + steps, + ) - # Continue loop only if we need another response - if exec_config.verbose: - print(f" Tool processing complete") + if exec_config.verbose: + print(f" Messages after tool execution: {len(messages)}") else: if exec_config.verbose: print(f" No tool calls found, breaking out of loop") @@ -1244,47 +1175,28 @@ async def _execute_tool_calls( step: int, verbose: bool = False, ): - """Execute tool calls asynchronously.""" + """Execute tool calls asynchronously with dependency-aware scheduling. + + Independent tools fire concurrently. Dependent tools wait for + their prerequisites. Falls back to sequential on cyclic deps. + """ + from ._scheduler import execute_local_tools_async + if verbose: print(f" _execute_tool_calls: Processing {len(tool_calls)} tool calls") - # Record single assistant message with ALL tool calls (OpenAI format) + # Record assistant message with all tool calls (OpenAI format). 
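+        # The scheduler only appends the per-call tool messages, so this
+        # assistant message must be recorded first to keep the transcript valid.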
messages.append({"role": "assistant", "tool_calls": list(tool_calls)}) - for i, tc in enumerate(tool_calls): - fn_name = tc["function"]["name"] - fn_args_str = tc["function"]["arguments"] - - if verbose: - print(f" Tool {i + 1}/{len(tool_calls)}: {fn_name}") - - try: - fn_args = json.loads(fn_args_str) - except json.JSONDecodeError: - fn_args = {} - - try: - result = await tool_handler.exec(fn_name, fn_args) - tool_results.append({"name": fn_name, "result": result, "step": step}) - tools_called.append(fn_name) - messages.append({"role": "tool", "tool_call_id": tc["id"], "content": str(result)}) - - if verbose: - print(f" Tool {fn_name} executed successfully: {str(result)[:50]}...") - except Exception as e: - error_result = {"error": str(e), "name": fn_name, "step": step} - tool_results.append(error_result) - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": f"Error: {str(e)}", - } - ) - - if verbose: - print(f" Tool {fn_name} failed with error: {e}") - print(f" Error type: {type(e).__name__}") + await execute_local_tools_async( + tool_calls, + tool_handler, + messages, + tool_results, + tools_called, + step, + verbose=verbose, + ) def _execute_tool_calls_sync( self, @@ -1295,34 +1207,20 @@ def _execute_tool_calls_sync( tools_called: list[str], step: int, ): - """Execute tool calls synchronously.""" - # Record single assistant message with ALL tool calls (OpenAI format) - messages.append({"role": "assistant", "tool_calls": list(tool_calls)}) - - for tc in tool_calls: - fn_name = tc["function"]["name"] - fn_args_str = tc["function"]["arguments"] + """Execute tool calls synchronously with dependency-aware ordering.""" + from ._scheduler import execute_local_tools_sync - try: - fn_args = json.loads(fn_args_str) - except json.JSONDecodeError: - fn_args = {} + # Record assistant message with all tool calls (OpenAI format). + messages.append({"role": "assistant", "tool_calls": list(tool_calls)}) - try: - result = tool_handler.exec_sync(fn_name, fn_args) - tool_results.append({"name": fn_name, "result": result, "step": step}) - tools_called.append(fn_name) - messages.append({"role": "tool", "tool_call_id": tc["id"], "content": str(result)}) - except Exception as e: - error_result = {"error": str(e), "name": fn_name, "step": step} - tool_results.append(error_result) - messages.append( - { - "role": "tool", - "tool_call_id": tc["id"], - "content": f"Error: {str(e)}", - } - ) + execute_local_tools_sync( + tool_calls, + tool_handler, + messages, + tool_results, + tools_called, + step, + ) def _accumulate_tool_calls(self, deltas, acc: list[ToolCall]) -> None: """Accumulate streaming tool call deltas.""" @@ -1350,17 +1248,15 @@ def _accumulate_tool_calls(self, deltas, acc: list[ToolCall]) -> None: @staticmethod def _mk_kwargs(mc: _ModelConfig) -> Dict[str, Any]: - """Convert model config to kwargs for client call.""" + """Convert model config to kwargs for the API call.""" from ..._utils import is_given from ...lib._parsing import type_to_response_format_param - d = asdict(mc) - d.pop("id", None) # Remove id since it's passed separately - d.pop("model_list", None) # Remove model_list since it's not an API parameter + kwargs = dict(mc.api_kwargs) - # Convert Pydantic model to dict schema if needed - if "response_format" in d and d["response_format"] is not None: - converted = type_to_response_format_param(d["response_format"]) - d["response_format"] = converted if is_given(converted) else None + # Convert Pydantic model class to dict schema if needed. 
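+        # e.g. a Pydantic model class passed as response_format (illustrative)
+        # becomes the JSON-schema dict form the API expects.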
+ if "response_format" in kwargs and kwargs["response_format"] is not None: + converted = type_to_response_format_param(kwargs["response_format"]) + kwargs["response_format"] = converted if is_given(converted) else None - return {k: v for k, v in d.items() if v is not None} + return {k: v for k, v in kwargs.items() if v is not None} diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py new file mode 100644 index 0000000..99ba7d0 --- /dev/null +++ b/tests/test_local_scheduler.py @@ -0,0 +1,278 @@ +# ============================================================================== +# © 2025 Dedalus Labs, Inc. and affiliates +# Licensed under MIT +# github.com/dedalus-labs/dedalus-sdk-python/LICENSE +# ============================================================================== + +"""Tests for SDK-side dependency-aware local tool scheduler.""" + +from __future__ import annotations + +import asyncio +import json +import time +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from dedalus_labs.lib.runner._scheduler import ( + execute_local_tools_async, + execute_local_tools_sync, +) + + +# --- Test helpers --- + + +def _make_tool_call( + call_id: str, + name: str, + args: Dict[str, Any] | None = None, + dependencies: list[str] | None = None, +) -> Dict[str, Any]: + """Build a tool call dict matching the format from core.py.""" + return { + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args or {}), + }, + "dependencies": dependencies or [], + } + + +def _make_async_handler( + results: Dict[str, Any] | None = None, + delay: float = 0.0, +) -> AsyncMock: + """Create a mock tool handler with async exec.""" + handler = MagicMock() + results = results or {} + + async def exec_fn(name: str, args: Dict[str, Any]) -> Any: + if delay: + await asyncio.sleep(delay) + if name in results: + return results[name] + return f"result_{name}" + + handler.exec = AsyncMock(side_effect=exec_fn) + return handler + + +def _make_sync_handler(results: Dict[str, Any] | None = None) -> MagicMock: + """Create a mock tool handler with sync exec_sync.""" + handler = MagicMock() + results = results or {} + + def exec_sync_fn(name: str, args: Dict[str, Any]) -> Any: + if name in results: + return results[name] + return f"result_{name}" + + handler.exec_sync = MagicMock(side_effect=exec_sync_fn) + return handler + + +# --- Async tests --- + + +@pytest.mark.asyncio +async def test_async_independent_tools_run_in_parallel(): + """Independent tools with no deps fire concurrently.""" + handler = _make_async_handler(delay=0.05) + calls = [ + _make_tool_call("a", "fetch_a"), + _make_tool_call("b", "fetch_b"), + _make_tool_call("c", "fetch_c"), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + start = time.monotonic() + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + elapsed = time.monotonic() - start + + assert len(tools_called) == 3 + # Parallel: should complete in ~1x delay, not 3x. 
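+    # 0.15s is the fully sequential total (3 x 0.05s), so this bound only
+    # holds when the three calls actually overlap.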
+ assert elapsed < 0.15 + + +@pytest.mark.asyncio +async def test_async_chain_respects_ordering(): + """b depends on a: a executes before b.""" + execution_order: list[str] = [] + handler = MagicMock() + + async def tracking_exec(name: str, args: Dict[str, Any]) -> str: + execution_order.append(name) + return f"done_{name}" + + handler.exec = AsyncMock(side_effect=tracking_exec) + + calls = [ + _make_tool_call("a", "fetch"), + _make_tool_call("b", "transform", dependencies=["a"]), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + assert execution_order == ["fetch", "transform"] + + +@pytest.mark.asyncio +async def test_async_diamond_dependency(): + """Diamond: a,b independent, c depends on both.""" + execution_order: list[str] = [] + handler = MagicMock() + + async def tracking_exec(name: str, args: Dict[str, Any]) -> str: + execution_order.append(name) + return f"done_{name}" + + handler.exec = AsyncMock(side_effect=tracking_exec) + + calls = [ + _make_tool_call("a", "fetch_a"), + _make_tool_call("b", "fetch_b"), + _make_tool_call("c", "combine", dependencies=["a", "b"]), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + # a and b execute first (any order), then c. + assert execution_order[-1] == "combine" + assert set(execution_order[:2]) == {"fetch_a", "fetch_b"} + + +@pytest.mark.asyncio +async def test_async_cycle_falls_back_to_sequential(): + """Cyclic deps don't crash. Falls back to sequential execution.""" + handler = _make_async_handler() + calls = [ + _make_tool_call("a", "fetch", dependencies=["b"]), + _make_tool_call("b", "fetch", dependencies=["a"]), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + assert len(tools_called) == 2 + + +@pytest.mark.asyncio +async def test_async_records_messages_correctly(): + """Tool results are recorded as messages in correct format.""" + handler = _make_async_handler(results={"my_tool": "hello world"}) + calls = [_make_tool_call("call_1", "my_tool", args={"x": 1})] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + + # Scheduler only appends tool result messages (assistant message is caller's job). 
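+    # Hence the single entry below is the tool message, in OpenAI tool-result format.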
+ assert messages[0]["role"] == "tool" + assert messages[0]["tool_call_id"] == "call_1" + assert messages[0]["content"] == "hello world" + + +@pytest.mark.asyncio +async def test_async_tool_error_recorded(): + """Tool execution errors are recorded in messages, not raised.""" + handler = MagicMock() + + async def failing_exec(name: str, args: Dict[str, Any]) -> str: + raise RuntimeError("boom") + + handler.exec = AsyncMock(side_effect=failing_exec) + calls = [_make_tool_call("a", "bad_tool")] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + assert "error" in tool_results[0] + assert messages[0]["content"] == "Error: boom" + + +@pytest.mark.asyncio +async def test_async_empty_calls(): + """Empty call list is a no-op.""" + handler = _make_async_handler() + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async([], handler, messages, tool_results, tools_called, step=1) + assert messages == [] + + +@pytest.mark.asyncio +async def test_async_unknown_deps_ignored(): + """Dependencies referencing ids not in this batch are filtered.""" + handler = _make_async_handler() + calls = [_make_tool_call("a", "fetch", dependencies=["nonexistent"])] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + await execute_local_tools_async(calls, handler, messages, tool_results, tools_called, step=1) + assert len(tools_called) == 1 + + +# --- Sync tests --- + + +def test_sync_chain_respects_ordering(): + """Sync path: b depends on a, executes in correct order.""" + execution_order: list[str] = [] + handler = MagicMock() + + def tracking_sync(name: str, args: Dict[str, Any]) -> str: + execution_order.append(name) + return f"done_{name}" + + handler.exec_sync = MagicMock(side_effect=tracking_sync) + + calls = [ + _make_tool_call("a", "fetch"), + _make_tool_call("b", "transform", dependencies=["a"]), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + execute_local_tools_sync(calls, handler, messages, tool_results, tools_called, step=1) + assert execution_order == ["fetch", "transform"] + + +def test_sync_cycle_falls_back(): + """Sync path: cycle detected, falls back to sequential.""" + handler = _make_sync_handler() + calls = [ + _make_tool_call("a", "fetch", dependencies=["b"]), + _make_tool_call("b", "fetch", dependencies=["a"]), + ] + messages: list = [] + tool_results: list = [] + tools_called: list = [] + + execute_local_tools_sync(calls, handler, messages, tool_results, tools_called, step=1) + assert len(tools_called) == 2 + + +def test_sync_empty_calls(): + """Sync path: empty call list is a no-op.""" + handler = _make_sync_handler() + messages: list = [] + execute_local_tools_sync([], handler, messages, [], [], step=1) + assert messages == []