diff --git a/src/dspy_cli/server/routes.py b/src/dspy_cli/server/routes.py index 4e0732a..83c08a5 100644 --- a/src/dspy_cli/server/routes.py +++ b/src/dspy_cli/server/routes.py @@ -1,10 +1,12 @@ """Dynamic route generation for DSPy programs.""" +import asyncio import logging from typing import Any, Dict import dspy from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse from pydantic import create_model from dspy_cli.discovery import DiscoveredModule @@ -12,6 +14,9 @@ from dspy_cli.gateway import APIGateway, IdentityGateway from dspy_cli.server.execution import _convert_dspy_types, execute_pipeline +DEFAULT_MAX_CONCURRENT = 20 +_program_semaphores: Dict[str, asyncio.Semaphore] = {} + logger = logging.getLogger(__name__) @@ -77,8 +82,21 @@ def create_program_routes( else: route_path = f"/{program_name}/{gateway.__class__.__name__}" + max_concurrent = config.get("server", {}).get("max_concurrent_per_program", DEFAULT_MAX_CONCURRENT) + if program_name not in _program_semaphores: + _program_semaphores[program_name] = asyncio.Semaphore(max_concurrent) + sem = _program_semaphores[program_name] + async def run_program(request: request_model): """Execute the DSPy program with given inputs.""" + try: + await asyncio.wait_for(sem.acquire(), timeout=30.0) + except asyncio.TimeoutError: + return JSONResponse( + status_code=429, + content={"detail": f"Too many concurrent requests for '{program_name}'. Try again later."}, + ) + try: pipeline_inputs = gateway.to_pipeline_inputs(request) @@ -100,6 +118,8 @@ async def run_program(request: request_model): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + finally: + sem.release() # Initialize gateway lifecycle gateway.setup()