Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/dspy_cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ def _exec_clean(target_python: Path, args: list[str]) -> NoReturn:
default=False,
help="Enable API authentication via DSPY_API_KEY (default: disabled)",
)
def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth):
@click.option(
"--sync-workers",
default=None,
type=click.IntRange(1, 200),
help="Number of threads for sync module execution (default: min(32, cpu+4))",
)
def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, system, mcp, auth, sync_workers):
"""Start an HTTP API server that exposes your DSPy programs.

This command:
Expand All @@ -127,6 +133,7 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
openapi_format=openapi_format,
mcp=mcp,
auth=auth,
sync_workers=sync_workers,
)
return

Expand Down Expand Up @@ -192,6 +199,8 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
args.append("--mcp")
if auth:
args.append("--auth")
if sync_workers is not None:
args.extend(["--sync-workers", str(sync_workers)])

_exec_clean(target_python, args)
else:
Expand All @@ -205,4 +214,5 @@ def serve(port, host, logs_dir, reload, save_openapi, openapi_format, python, sy
openapi_format=openapi_format,
mcp=mcp,
auth=auth,
sync_workers=sync_workers,
)
9 changes: 9 additions & 0 deletions src/dspy_cli/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dspy_cli.discovery import discover_modules
from dspy_cli.discovery.gateway_finder import get_gateways_for_module, is_cron_gateway
from dspy_cli.gateway import APIGateway, IdentityGateway
from dspy_cli.server.executor import init_executor, shutdown_executor, DEFAULT_SYNC_WORKERS
from dspy_cli.server.logging import setup_logging
from dspy_cli.server.metrics import get_all_metrics, get_program_metrics_cached
from dspy_cli.server.routes import create_program_routes
Expand All @@ -31,6 +32,7 @@ def create_app(
logs_dir: Path,
enable_ui: bool = True,
enable_auth: bool = False,
sync_workers: int | None = None,
) -> FastAPI:
"""Create and configure the FastAPI application.

Expand All @@ -41,13 +43,18 @@ def create_app(
logs_dir: Directory for log files
enable_ui: Whether to enable the web UI (always True, kept for compatibility)
enable_auth: Whether to enable API authentication via DSPY_API_KEY
sync_workers: Number of threads for sync module execution (overrides config)

Returns:
Configured FastAPI application
"""
# Setup logging
setup_logging()

# Initialize bounded executor for sync module execution
worker_count = sync_workers or config.get("server", {}).get("sync_worker_threads") or DEFAULT_SYNC_WORKERS
init_executor(max_workers=worker_count)

# Create FastAPI app
app = FastAPI(
title="DSPy API",
Expand Down Expand Up @@ -332,6 +339,8 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.warning(f"Gateway shutdown error: {e}")

shutdown_executor()


def _create_lm_instance(model_config: Dict) -> dspy.LM:
"""Create a DSPy LM instance from configuration.
Expand Down
7 changes: 5 additions & 2 deletions src/dspy_cli/server/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import dspy

from dspy_cli.discovery import DiscoveredModule
from dspy_cli.server.executor import run_sync_in_executor
from dspy_cli.server.logging import log_inference

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -280,7 +281,7 @@ async def execute_pipeline(
if hasattr(instance, 'aforward'):
result = await instance.acall(**inputs)
else:
result = instance(**inputs)
result = await run_sync_in_executor(instance, **inputs)

output = _normalize_output(result, module)
duration_ms = (time.time() - start_time) * 1000
Expand Down Expand Up @@ -393,7 +394,9 @@ async def execute_pipeline_batch(
if max_errors is not None:
batch_kwargs["max_errors"] = max_errors

batch_result = instance.batch(examples, **batch_kwargs)
batch_result = await run_sync_in_executor(
instance.batch, examples, **batch_kwargs
)

if isinstance(batch_result, tuple) and len(batch_result) == 3:
successful, failed_examples, exceptions = batch_result
Expand Down
51 changes: 51 additions & 0 deletions src/dspy_cli/server/executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Bounded thread pool executor for sync DSPy module execution.

Sync forward() calls are dispatched here so they don't block the async
event loop. Context variables (including dspy.context overrides) are
propagated into the worker thread automatically.
"""

import asyncio
import contextvars
import functools
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Optional

logger = logging.getLogger(__name__)

_executor: Optional[ThreadPoolExecutor] = None

DEFAULT_SYNC_WORKERS = min(32, (os.cpu_count() or 1) + 4)


def init_executor(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
    """Create (or replace) the process-wide bounded executor.

    Any previously installed executor is shut down without waiting for
    in-flight work before the replacement is created.

    Args:
        max_workers: Thread count for the pool; any falsy value falls
            back to DEFAULT_SYNC_WORKERS.

    Returns:
        The newly created ThreadPoolExecutor.
    """
    global _executor
    previous = _executor
    if previous is not None:
        # Replace eagerly: do not block on work queued to the old pool.
        previous.shutdown(wait=False)

    pool_size = max_workers or DEFAULT_SYNC_WORKERS
    _executor = ThreadPoolExecutor(
        max_workers=pool_size, thread_name_prefix="dspy-sync"
    )
    logger.info(f"Initialized sync executor with {pool_size} worker threads")
    return _executor


def shutdown_executor() -> None:
    """Tear down the module-level executor, blocking until queued work finishes.

    No-op when the executor was never initialized (or is already shut down).
    """
    global _executor
    pool, _executor = _executor, None
    if pool is not None:
        pool.shutdown(wait=True)


async def run_sync_in_executor(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run a sync callable in the bounded executor with context propagation.

    The caller's context (including dspy.context overrides stored in
    contextvars) is copied into the worker thread so per-request state
    survives the thread hop.

    Falls back to the event loop's default executor if init_executor()
    hasn't been called (``_executor`` is still None).
    """
    # Snapshot the current context now, on the event-loop thread, and
    # invoke ``fn`` inside that snapshot on the worker thread.
    snapshot = contextvars.copy_context()
    bound = functools.partial(snapshot.run, fn, *args, **kwargs)
    return await asyncio.get_running_loop().run_in_executor(_executor, bound)
10 changes: 10 additions & 0 deletions src/dspy_cli/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ENV_ENABLE_MCP = "DSPY_CLI_ENABLE_MCP"
ENV_LOGS_DIR = "DSPY_CLI_LOGS_DIR"
ENV_AUTH_ENABLED = "DSPY_CLI_AUTH_ENABLED"
ENV_SYNC_WORKERS = "DSPY_CLI_SYNC_WORKERS"


def _maybe_mount_mcp(app, enable: bool, *, path: str = MCP_DEFAULT_PATH, notify=None) -> bool:
Expand Down Expand Up @@ -86,6 +87,8 @@ def create_app_instance():
logs_dir = os.environ.get(ENV_LOGS_DIR, "./logs")
enable_mcp = os.environ.get(ENV_ENABLE_MCP, "false").lower() == "true"
enable_auth = os.environ.get(ENV_AUTH_ENABLED, "false").lower() == "true"
sync_workers_str = os.environ.get(ENV_SYNC_WORKERS)
sync_workers = int(sync_workers_str) if sync_workers_str else None

# Validate project structure
if not validate_project_structure():
Expand Down Expand Up @@ -118,6 +121,7 @@ def create_app_instance():
logs_dir=logs_path,
enable_ui=True,
enable_auth=enable_auth,
sync_workers=sync_workers,
)

# Mount MCP if enabled
Expand All @@ -135,6 +139,7 @@ def main(
openapi_format: str = "json",
mcp: bool = False,
auth: bool = False,
sync_workers: int | None = None,
):
"""Main server execution logic.

Expand Down Expand Up @@ -192,6 +197,7 @@ def main(
logs_dir=logs_path,
enable_ui=True,
enable_auth=auth,
sync_workers=sync_workers,
)

# Mount MCP if enabled
Expand Down Expand Up @@ -275,6 +281,8 @@ def notify_cli(msg: str, level: str = "info"):
os.environ[ENV_LOGS_DIR] = str(logs_path)
os.environ[ENV_ENABLE_MCP] = str(mcp).lower()
os.environ[ENV_AUTH_ENABLED] = str(auth).lower()
if sync_workers is not None:
os.environ[ENV_SYNC_WORKERS] = str(sync_workers)

# Get project root and src directory for watching
project_root = Path.cwd()
Expand Down Expand Up @@ -318,6 +326,7 @@ def notify_cli(msg: str, level: str = "info"):
parser.add_argument("--openapi-format", choices=["json", "yaml"], default="json")
parser.add_argument("--mcp", action="store_true", help="Enable MCP server at /mcp")
parser.add_argument("--auth", action="store_true", help="Enable API authentication")
parser.add_argument("--sync-workers", type=int, default=None, help="Number of sync worker threads")
args = parser.parse_args()

main(
Expand All @@ -329,4 +338,5 @@ def notify_cli(msg: str, level: str = "info"):
openapi_format=args.openapi_format,
mcp=args.mcp,
auth=args.auth,
sync_workers=args.sync_workers,
)
1 change: 1 addition & 0 deletions tests/test_commands_smoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def fake_runner_main(**kwargs):
"openapi_format": "yaml",
"mcp": False,
"auth": False,
"sync_workers": None,
}


Expand Down
91 changes: 91 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for the bounded executor and context propagation."""

import asyncio
import contextvars

import pytest

from dspy_cli.server.executor import (
init_executor,
run_sync_in_executor,
shutdown_executor,
)


@pytest.fixture(autouse=True)
def _clean_executor():
    """Ensure each test gets a fresh executor."""
    # Teardown-only fixture: yield first so the test body runs, then shut
    # down whatever executor it may have created so module-level state does
    # not leak into the next test.
    yield
    shutdown_executor()


class TestContextPropagation:
    """Verify contextvars flow from the event loop into executor threads.

    All tests drive coroutines with asyncio.run(): calling
    asyncio.get_event_loop().run_until_complete() with no running loop is
    deprecated since Python 3.10 (emits DeprecationWarning on 3.12+) and
    leaves the created loop unclosed.
    """

    def test_contextvar_propagates_to_executor_thread(self):
        """A value set in the async context is visible in the worker thread."""
        cv = contextvars.ContextVar("test_cv", default="UNSET")
        init_executor(max_workers=2)

        def read_cv():
            return cv.get()

        async def run():
            cv.set("per-request-value")
            return await run_sync_in_executor(read_cv)

        result = asyncio.run(run())
        assert result == "per-request-value"

    def test_concurrent_requests_see_own_context(self):
        """Interleaved concurrent requests each observe their own value."""
        cv = contextvars.ContextVar("test_cv", default="UNSET")
        init_executor(max_workers=4)

        results = {}

        def read_cv():
            # Sleep so the three workers genuinely overlap in time.
            import time
            time.sleep(0.05)
            return cv.get()

        async def make_request(name: str, value: str):
            cv.set(value)
            results[name] = await run_sync_in_executor(read_cv)

        async def run():
            await asyncio.gather(
                make_request("a", "alpha"),
                make_request("b", "beta"),
                make_request("c", "gamma"),
            )

        asyncio.run(run())
        assert results == {"a": "alpha", "b": "beta", "c": "gamma"}

    def test_dspy_context_lm_propagates(self):
        """dspy.context overrides (contextvar-backed) reach the worker thread."""
        import dspy

        init_executor(max_workers=2)

        def read_lm():
            return dspy.settings.lm

        async def run():
            sentinel = object()
            with dspy.context(lm=sentinel):
                result = await run_sync_in_executor(read_lm)
            return result, sentinel

        result, sentinel = asyncio.run(run())
        assert result is sentinel

    def test_fallback_without_init(self):
        """Without init_executor(), the loop's default executor still works."""
        cv = contextvars.ContextVar("test_cv", default="UNSET")

        def read_cv():
            return cv.get()

        async def run():
            cv.set("fallback-value")
            return await run_sync_in_executor(read_cv)

        result = asyncio.run(run())
        assert result == "fallback-value"