1 change: 1 addition & 0 deletions delphi/__main__.py
@@ -151,6 +151,7 @@ async def process_cache(
max_model_len=run_cfg.explainer_model_max_len,
num_gpus=run_cfg.num_gpus,
statistics=run_cfg.verbose,
server_port=run_cfg.server_port,
)
elif run_cfg.explainer_provider == "openrouter":
if (
102 changes: 90 additions & 12 deletions delphi/clients/offline.py
@@ -5,6 +5,7 @@
from pathlib import Path
from typing import Union

from openai import AsyncOpenAI
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (
@@ -51,27 +52,47 @@ def __init__(
num_gpus: int = 2,
enforce_eager: bool = False,
statistics: bool = False,
server_port: int | None = None,
):
"""Client for offline generation. Models not already present in the on-disk
HuggingFace cache will be downloaded. Note that temperature must be increased
for best-of-n sampling.

If server_port is provided, connects to an external vLLM server running on
localhost at the specified port via the OpenAI-compatible API, instead of
loading the model locally.
"""
super().__init__(model)
self.model = model
self.max_model_len = max_model_len
self.queue = asyncio.Queue()
self.task = None
self.client = LLM(
model=model,
gpu_memory_utilization=max_memory,
enable_prefix_caching=prefix_caching,
tensor_parallel_size=num_gpus,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
)
self.server_port = server_port

if server_port is None:
# Local mode: load model in-process
self.client = LLM(
model=model,
gpu_memory_utilization=max_memory,
enable_prefix_caching=prefix_caching,
tensor_parallel_size=num_gpus,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
)
self.openai_client = None
else:
# Server mode: connect to external vLLM server
self.client = None
self.openai_client = AsyncOpenAI(
base_url=f"http://localhost:{server_port}/v1",
api_key="EMPTY",
)

self.sampling_params = SamplingParams(max_tokens=number_tokens_to_generate)
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.batch_size = batch_size
self.statistics = statistics
self.number_tokens_to_generate = number_tokens_to_generate

if self.statistics:
self.statistics_path = Path("statistics")
@@ -168,12 +189,62 @@ async def generate(
"""
Enqueue a request and wait for the result.
"""
if self.server_port is not None:
# Server mode: use OpenAI-compatible API directly
return await self._generate_server(prompt, **kwargs)

# Local mode: use batching queue
future = asyncio.Future()
if self.task is None:
self.task = asyncio.create_task(self._process_batches())
await self.queue.put((prompt, future, kwargs))
return await future

async def _generate_server(
self, prompt: Union[str, list[dict[str, str]]], **kwargs
) -> Response:
"""
Generate using external vLLM server via OpenAI-compatible API.
"""
temperature = kwargs.get("temperature", 0.0)
max_tokens = kwargs.get("max_tokens", self.number_tokens_to_generate)

# Handle logprobs if requested
logprobs = kwargs.get("logprobs", False)
top_logprobs = kwargs.get("top_logprobs", None) if logprobs else None

messages = (
prompt
if isinstance(prompt, list)
else [{"role": "user", "content": prompt}]
)

response = await self.openai_client.chat.completions.create(
model=self.model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
logprobs=logprobs,
top_logprobs=top_logprobs,
)

text = response.choices[0].message.content or ""

# Parse logprobs from OpenAI format if present
parsed_logprobs = None
if logprobs and response.choices[0].logprobs:
parsed_logprobs = []
for token_logprob in response.choices[0].logprobs.content or []:
top_lps = [
Top_Logprob(token=lp.token, logprob=lp.logprob)
for lp in (token_logprob.top_logprobs or [])
]
parsed_logprobs.append(
Logprobs(token=token_logprob.token, top_logprobs=top_lps)
)

return Response(text=text, logprobs=parsed_logprobs, prompt_logprobs=None)

def _parse_logprobs(self, response):
response_tokens = response.outputs[0].token_ids
logprobs = response.outputs[0].logprobs
@@ -253,10 +324,17 @@ async def close(self):
"""
Clean up resources when the client is no longer needed.
"""
destroy_model_parallel()
destroy_distributed_environment()
del self.client
self.client = None
if self.client is not None:
# Only destroy local model resources in local mode
destroy_model_parallel()
destroy_distributed_environment()
del self.client
self.client = None

if self.openai_client is not None:
await self.openai_client.close()
self.openai_client = None

if self.task:
self.task.cancel()
try:
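The server-mode path above only uses the OpenAI-compatible endpoints, so it can be exercised against any running vLLM server. A minimal usage sketch follows, assuming the client class defined in this module is named Offline, that a server was started with "vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000", and that the model name passed to the constructor matches the served model; the model id and prompt are placeholders, not taken from this diff.

# Sketch only: the Offline class name and the model id are assumptions for
# illustration. The tokenizer is still loaded from the HuggingFace cache, so
# `model` must be a valid model id even in server mode.
import asyncio

from delphi.clients.offline import Offline


async def main() -> None:
    client = Offline(
        "meta-llama/Llama-3.1-8B-Instruct",  # must match the model served by vLLM
        server_port=8000,  # server mode: no in-process LLM is created
    )
    response = await client.generate(
        [{"role": "user", "content": "Describe the pattern in these activations."}],
        temperature=0.5,
        max_tokens=200,
    )
    print(response.text)
    await client.close()


asyncio.run(main())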
5 changes: 5 additions & 0 deletions delphi/config.py
@@ -140,6 +140,11 @@ class RunConfig(Serializable):
"""Provider to use for explanation and scoring. Options are 'offline' for local
models and 'openrouter' for API calls."""

server_port: int | None = field(default=None)
"""Port for external vLLM server. If set, connects to a vLLM server running on
localhost at this port via OpenAI-compatible API instead of loading the model
locally. Start a server with: vllm serve <model> --port <port>"""

explainer: str = field(
choices=["default", "none"],
default="default",
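Before a full run it may help to confirm the server is reachable and that the model id it serves matches the explainer model configured here, since the client sends that name with every chat completion request. A minimal connectivity check, assuming a server already started on port 8000 (the port is a placeholder and this helper is not part of delphi):

# Sketch only: quick check that the external vLLM server is up and serving the
# expected model; the port value is an assumption for illustration.
import asyncio

from openai import AsyncOpenAI


async def check(port: int = 8000) -> None:
    client = AsyncOpenAI(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")
    models = await client.models.list()
    # The id listed here should match the explainer model name used by the run,
    # because the offline client passes that name as `model` in each request.
    print([m.id for m in models.data])
    await client.close()


asyncio.run(check())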
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -28,7 +28,8 @@ dependencies = [
"anyio>=4.8.0",
"faiss-cpu",
"asyncer>=0.0.8",
"beartype"
"beartype",
"openai>=1.0.0",
]

[project.optional-dependencies]