diff --git a/delphi/__main__.py b/delphi/__main__.py
index d69d7b10..bc1cc20e 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -151,6 +151,7 @@ async def process_cache(
             max_model_len=run_cfg.explainer_model_max_len,
             num_gpus=run_cfg.num_gpus,
             statistics=run_cfg.verbose,
+            server_port=run_cfg.server_port,
         )
     elif run_cfg.explainer_provider == "openrouter":
         if (
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index 7ad8adbf..ac6908a0 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Union
 
+from openai import AsyncOpenAI
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
 from vllm.distributed.parallel_state import (
@@ -51,27 +52,47 @@ def __init__(
         num_gpus: int = 2,
         enforce_eager: bool = False,
         statistics: bool = False,
+        server_port: int | None = None,
     ):
         """Client for offline generation. Models not already present in the on-disk
         HuggingFace cache will be downloaded.
         Note that temperature must be increased for best-of-n sampling.
+
+        If server_port is provided, connects to an external vLLM server running on
+        localhost at the specified port via the OpenAI-compatible API, instead of
+        loading the model locally.
         """
         super().__init__(model)
         self.model = model
+        self.max_model_len = max_model_len
         self.queue = asyncio.Queue()
         self.task = None
-        self.client = LLM(
-            model=model,
-            gpu_memory_utilization=max_memory,
-            enable_prefix_caching=prefix_caching,
-            tensor_parallel_size=num_gpus,
-            max_model_len=max_model_len,
-            enforce_eager=enforce_eager,
-        )
+        self.server_port = server_port
+
+        if server_port is None:
+            # Local mode: load model in-process
+            self.client = LLM(
+                model=model,
+                gpu_memory_utilization=max_memory,
+                enable_prefix_caching=prefix_caching,
+                tensor_parallel_size=num_gpus,
+                max_model_len=max_model_len,
+                enforce_eager=enforce_eager,
+            )
+            self.openai_client = None
+        else:
+            # Server mode: connect to external vLLM server
+            self.client = None
+            self.openai_client = AsyncOpenAI(
+                base_url=f"http://localhost:{server_port}/v1",
+                api_key="EMPTY",
+            )
+
         self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
         self.tokenizer = AutoTokenizer.from_pretrained(model)
         self.batch_size = batch_size
         self.statistics = statistics
+        self.number_tokens_to_generate = number_tokens_to_generate
 
         if self.statistics:
             self.statistics_path = Path("statistics")
@@ -168,12 +189,62 @@ async def generate(
         """
         Enqueue a request and wait for the result.
         """
+        if self.server_port is not None:
+            # Server mode: use OpenAI-compatible API directly
+            return await self._generate_server(prompt, **kwargs)
+
+        # Local mode: use batching queue
         future = asyncio.Future()
         if self.task is None:
             self.task = asyncio.create_task(self._process_batches())
         await self.queue.put((prompt, future, kwargs))
         return await future
 
+    async def _generate_server(
+        self, prompt: Union[str, list[dict[str, str]]], **kwargs
+    ) -> Response:
+        """
+        Generate using external vLLM server via OpenAI-compatible API.
+ """ + temperature = kwargs.get("temperature", 0.0) + max_tokens = kwargs.get("max_tokens", self.number_tokens_to_generate) + + # Handle logprobs if requested + logprobs = kwargs.get("logprobs", False) + top_logprobs = kwargs.get("top_logprobs", None) if logprobs else None + + messages = ( + prompt + if isinstance(prompt, list) + else [{"role": "user", "content": prompt}] + ) + + response = await self.openai_client.chat.completions.create( + model=self.model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + logprobs=logprobs, + top_logprobs=top_logprobs, + ) + + text = response.choices[0].message.content or "" + + # Parse logprobs from OpenAI format if present + parsed_logprobs = None + if logprobs and response.choices[0].logprobs: + parsed_logprobs = [] + for token_logprob in response.choices[0].logprobs.content or []: + top_lps = [ + Top_Logprob(token=lp.token, logprob=lp.logprob) + for lp in (token_logprob.top_logprobs or []) + ] + parsed_logprobs.append( + Logprobs(token=token_logprob.token, top_logprobs=top_lps) + ) + + return Response(text=text, logprobs=parsed_logprobs, prompt_logprobs=None) + def _parse_logprobs(self, response): response_tokens = response.outputs[0].token_ids logprobs = response.outputs[0].logprobs @@ -253,10 +324,17 @@ async def close(self): """ Clean up resources when the client is no longer needed. """ - destroy_model_parallel() - destroy_distributed_environment() - del self.client - self.client = None + if self.client is not None: + # Only destroy local model resources in local mode + destroy_model_parallel() + destroy_distributed_environment() + del self.client + self.client = None + + if self.openai_client is not None: + await self.openai_client.close() + self.openai_client = None + if self.task: self.task.cancel() try: diff --git a/delphi/config.py b/delphi/config.py index de806157..4742c7d4 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -140,6 +140,11 @@ class RunConfig(Serializable): """Provider to use for explanation and scoring. Options are 'offline' for local models and 'openrouter' for API calls.""" + server_port: int | None = field(default=None) + """Port for external vLLM server. If set, connects to a vLLM server running on + localhost at this port via OpenAI-compatible API instead of loading the model + locally. Start a server with: vllm serve --port """ + explainer: str = field( choices=["default", "none"], default="default", diff --git a/pyproject.toml b/pyproject.toml index 554e8036..b1ffe558 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ dependencies = [ "anyio>=4.8.0", "faiss-cpu", "asyncer>=0.0.8", - "beartype" + "beartype", + "openai>=1.0.0", ] [project.optional-dependencies]