From 0aa12e7c5422115a3dc4d1cd539fb4d9f757c8f6 Mon Sep 17 00:00:00 2001 From: isaacbmiller Date: Sat, 28 Feb 2026 16:08:13 -0500 Subject: [PATCH] perf: run sync modules in bounded thread pool executor Move sync forward() and batch() calls off the event loop into a dedicated ThreadPoolExecutor with context variable propagation. Keeps the existing hasattr(instance, 'aforward') check for async detection. Configurable via --sync-workers flag or server.sync_worker_threads in dspy.config.yaml. Co-authored-by: factory-droid[bot] <138933559+factory-droid[bot]@users.noreply.github.com> --- src/dspy_cli/commands/serve.py | 12 ++++- src/dspy_cli/server/app.py | 9 ++++ src/dspy_cli/server/execution.py | 7 ++- src/dspy_cli/server/executor.py | 51 ++++++++++++++++++ src/dspy_cli/server/runner.py | 10 ++++ tests/test_commands_smoke.py | 1 + tests/test_executor.py | 91 ++++++++++++++++++++++++++++++++ 7 files changed, 178 insertions(+), 3 deletions(-) create mode 100644 src/dspy_cli/server/executor.py create mode 100644 tests/test_executor.py diff --git a/src/dspy_cli/commands/serve.py b/src/dspy_cli/commands/serve.py index 9c45798..fc55881 100644 --- a/src/dspy_cli/commands/serve.py +++ b/src/dspy_cli/commands/serve.py @@ -102,7 +102,13 @@ def _exec_clean(target_python: Path, args: list[str]) -> NoReturn: default=False, help="Enable API authentication via DSPY_API_KEY (default: disabled)", ) -def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth): +@click.option( + "--sync-workers", + default=None, + type=click.IntRange(1, 200), + help="Number of threads for sync module execution (default: min(32, cpu+4))", +) +def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth, sync_workers): """Start an HTTP API server that exposes your DSPy programs. This command: @@ -127,6 +133,7 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy openapi_format=openapi_format, mcp=mcp, auth=auth, + sync_workers=sync_workers, ) return @@ -192,6 +199,8 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy args.append("--mcp") if auth: args.append("--auth") + if sync_workers is not None: + args.extend(["--sync-workers", str(sync_workers)]) _exec_clean(target_python, args) else: @@ -205,4 +214,5 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy openapi_format=openapi_format, mcp=mcp, auth=auth, + sync_workers=sync_workers, ) diff --git a/src/dspy_cli/server/app.py b/src/dspy_cli/server/app.py index e017871..58bf1f3 100644 --- a/src/dspy_cli/server/app.py +++ b/src/dspy_cli/server/app.py @@ -15,6 +15,7 @@ from dspy_cli.discovery import discover_modules from dspy_cli.discovery.gateway_finder import get_gateways_for_module, is_cron_gateway from dspy_cli.gateway import APIGateway, IdentityGateway +from dspy_cli.server.executor import init_executor, shutdown_executor, DEFAULT_SYNC_WORKERS from dspy_cli.server.logging import setup_logging from dspy_cli.server.metrics import get_all_metrics, get_program_metrics_cached from dspy_cli.server.routes import create_program_routes @@ -31,6 +32,7 @@ def create_app( logs_dir: Path, enable_ui: bool = True, enable_auth: bool = False, + sync_workers: int | None = None, ) -> FastAPI: """Create and configure the FastAPI application. @@ -41,6 +43,7 @@ def create_app( logs_dir: Directory for log files enable_ui: Whether to enable the web UI (always True, kept for compatibility) enable_auth: Whether to enable API authentication via DSPY_API_KEY + sync_workers: Number of threads for sync module execution (overrides config) Returns: Configured FastAPI application @@ -48,6 +51,10 @@ def create_app( # Setup logging setup_logging() + # Initialize bounded executor for sync module execution + worker_count = sync_workers or config.get("server", {}).get("sync_worker_threads") or DEFAULT_SYNC_WORKERS + init_executor(max_workers=worker_count) + # Create FastAPI app app = FastAPI( title="DSPy API", @@ -332,6 +339,8 @@ async def lifespan(app: FastAPI): except Exception as e: logger.warning(f"Gateway shutdown error: {e}") + shutdown_executor() + def _create_lm_instance(model_config: Dict) -> dspy.LM: """Create a DSPy LM instance from configuration. diff --git a/src/dspy_cli/server/execution.py b/src/dspy_cli/server/execution.py index 8279392..2a7d4b8 100644 --- a/src/dspy_cli/server/execution.py +++ b/src/dspy_cli/server/execution.py @@ -10,6 +10,7 @@ import dspy from dspy_cli.discovery import DiscoveredModule +from dspy_cli.server.executor import run_sync_in_executor from dspy_cli.server.logging import log_inference logger = logging.getLogger(__name__) @@ -280,7 +281,7 @@ async def execute_pipeline( if hasattr(instance, 'aforward'): result = await instance.acall(**inputs) else: - result = instance(**inputs) + result = await run_sync_in_executor(instance, **inputs) output = _normalize_output(result, module) duration_ms = (time.time() - start_time) * 1000 @@ -393,7 +394,9 @@ async def execute_pipeline_batch( if max_errors is not None: batch_kwargs["max_errors"] = max_errors - batch_result = instance.batch(examples, **batch_kwargs) + batch_result = await run_sync_in_executor( + instance.batch, examples, **batch_kwargs + ) if isinstance(batch_result, tuple) and len(batch_result) == 3: successful, failed_examples, exceptions = batch_result diff --git a/src/dspy_cli/server/executor.py b/src/dspy_cli/server/executor.py new file mode 100644 index 0000000..ad5ffb2 --- /dev/null +++ b/src/dspy_cli/server/executor.py @@ -0,0 +1,51 @@ +"""Bounded thread pool executor for sync DSPy module execution. + +Sync forward() calls are dispatched here so they don't block the async +event loop. Context variables (including dspy.context overrides) are +propagated into the worker thread automatically. +""" + +import asyncio +import contextvars +import functools +import logging +import os +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + +_executor: Optional[ThreadPoolExecutor] = None + +DEFAULT_SYNC_WORKERS = min(32, (os.cpu_count() or 1) + 4) + + +def init_executor(max_workers: Optional[int] = None) -> ThreadPoolExecutor: + """Create the process-wide bounded executor.""" + global _executor + if _executor is not None: + _executor.shutdown(wait=False) + + workers = max_workers or DEFAULT_SYNC_WORKERS + _executor = ThreadPoolExecutor(max_workers=workers, thread_name_prefix="dspy-sync") + logger.info(f"Initialized sync executor with {workers} worker threads") + return _executor + + +def shutdown_executor() -> None: + """Shut down the executor, waiting for pending work.""" + global _executor + if _executor is not None: + _executor.shutdown(wait=True) + _executor = None + + +async def run_sync_in_executor(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Run a sync callable in the bounded executor with context propagation. + + Falls back to the default executor if init_executor() hasn't been called. + """ + loop = asyncio.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, fn, *args, **kwargs) + return await loop.run_in_executor(_executor, func_call) diff --git a/src/dspy_cli/server/runner.py b/src/dspy_cli/server/runner.py index b1636c3..2379c55 100644 --- a/src/dspy_cli/server/runner.py +++ b/src/dspy_cli/server/runner.py @@ -20,6 +20,7 @@ ENV_ENABLE_MCP = "DSPY_CLI_ENABLE_MCP" ENV_LOGS_DIR = "DSPY_CLI_LOGS_DIR" ENV_AUTH_ENABLED = "DSPY_CLI_AUTH_ENABLED" +ENV_SYNC_WORKERS = "DSPY_CLI_SYNC_WORKERS" def _maybe_mount_mcp(app, enable: bool, *, path: str = MCP_DEFAULT_PATH, notify=None) -> bool: @@ -86,6 +87,8 @@ def create_app_instance(): logs_dir = os.environ.get(ENV_LOGS_DIR, "./logs") enable_mcp = os.environ.get(ENV_ENABLE_MCP, "false").lower() == "true" enable_auth = os.environ.get(ENV_AUTH_ENABLED, "false").lower() == "true" + sync_workers_str = os.environ.get(ENV_SYNC_WORKERS) + sync_workers = int(sync_workers_str) if sync_workers_str else None # Validate project structure if not validate_project_structure(): @@ -118,6 +121,7 @@ def create_app_instance(): logs_dir=logs_path, enable_ui=True, enable_auth=enable_auth, + sync_workers=sync_workers, ) # Mount MCP if enabled @@ -135,6 +139,7 @@ def main( openapi_format: str = "json", mcp: bool = False, auth: bool = False, + sync_workers: int | None = None, ): """Main server execution logic. @@ -192,6 +197,7 @@ def main( logs_dir=logs_path, enable_ui=True, enable_auth=auth, + sync_workers=sync_workers, ) # Mount MCP if enabled @@ -275,6 +281,8 @@ def notify_cli(msg: str, level: str = "info"): os.environ[ENV_LOGS_DIR] = str(logs_path) os.environ[ENV_ENABLE_MCP] = str(mcp).lower() os.environ[ENV_AUTH_ENABLED] = str(auth).lower() + if sync_workers is not None: + os.environ[ENV_SYNC_WORKERS] = str(sync_workers) # Get project root and src directory for watching project_root = Path.cwd() @@ -318,6 +326,7 @@ def notify_cli(msg: str, level: str = "info"): parser.add_argument("--openapi-format", choices=["json", "yaml"], default="json") parser.add_argument("--mcp", action="store_true", help="Enable MCP server at /mcp") parser.add_argument("--auth", action="store_true", help="Enable API authentication") + parser.add_argument("--sync-workers", type=int, default=None, help="Number of sync worker threads") args = parser.parse_args() main( @@ -329,4 +338,5 @@ def notify_cli(msg: str, level: str = "info"): openapi_format=args.openapi_format, mcp=args.mcp, auth=args.auth, + sync_workers=args.sync_workers, ) diff --git a/tests/test_commands_smoke.py b/tests/test_commands_smoke.py index d508541..53cfb73 100644 --- a/tests/test_commands_smoke.py +++ b/tests/test_commands_smoke.py @@ -123,6 +123,7 @@ def fake_runner_main(**kwargs): "openapi_format": "yaml", "mcp": False, "auth": False, + "sync_workers": None, } diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..5c3fa2d --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,91 @@ +"""Tests for the bounded executor and context propagation.""" + +import asyncio +import contextvars + +import pytest + +from dspy_cli.server.executor import ( + init_executor, + run_sync_in_executor, + shutdown_executor, +) + + +@pytest.fixture(autouse=True) +def _clean_executor(): + """Ensure each test gets a fresh executor.""" + yield + shutdown_executor() + + +class TestContextPropagation: + + def test_contextvar_propagates_to_executor_thread(self): + cv = contextvars.ContextVar("test_cv", default="UNSET") + init_executor(max_workers=2) + + def read_cv(): + return cv.get() + + async def run(): + cv.set("per-request-value") + return await run_sync_in_executor(read_cv) + + result = asyncio.get_event_loop().run_until_complete(run()) + assert result == "per-request-value" + + def test_concurrent_requests_see_own_context(self): + cv = contextvars.ContextVar("test_cv", default="UNSET") + init_executor(max_workers=4) + + results = {} + + def read_cv(): + import time + time.sleep(0.05) + return cv.get() + + async def make_request(name: str, value: str): + cv.set(value) + results[name] = await run_sync_in_executor(read_cv) + + async def run(): + await asyncio.gather( + make_request("a", "alpha"), + make_request("b", "beta"), + make_request("c", "gamma"), + ) + + asyncio.get_event_loop().run_until_complete(run()) + assert results == {"a": "alpha", "b": "beta", "c": "gamma"} + + def test_dspy_context_lm_propagates(self): + import dspy + + init_executor(max_workers=2) + + def read_lm(): + return dspy.settings.lm + + async def run(): + sentinel = object() + with dspy.context(lm=sentinel): + result = await run_sync_in_executor(read_lm) + return result, sentinel + + result, sentinel = asyncio.get_event_loop().run_until_complete(run()) + assert result is sentinel + + def test_fallback_without_init(self): + cv = contextvars.ContextVar("test_cv", default="UNSET") + + def read_cv(): + return cv.get() + + async def run(): + cv.set("fallback-value") + return await run_sync_in_executor(read_cv) + + result = asyncio.get_event_loop().run_until_complete(run()) + assert result == "fallback-value"