
Commit b95ee39

[Disagg] local disagg e2e test (#1237)
1 parent 3cadb34 commit b95ee39

File tree

3 files changed: +279 -2 lines changed

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
# Single-Host-P-D-disaggregation
# features support matrix
steps:
  - label: "Correctness tests for Single-Host P-D disaggregation"
    key: "SingleHostPDDisaggregation_CorrectnessTest"
    soft_fail: true
    agents:
      queue: tpu_v6e_queue
    commands:
      - |
        .buildkite/scripts/run_in_docker.sh \
          python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_local_disagg.py::test_disaggregated_serving \
            /workspace/tpu_inference/tests/e2e/test_local_disagg.py::test_disaggregated_serving_correctness
  - label: "Record correctness test result for Single-Host P-D disaggregation"
    key: "record_SingleHostPDDisaggregation_CorrectnessTest"
    depends_on: "SingleHostPDDisaggregation_CorrectnessTest"
    env:
      CI_TARGET: "SingleHostPDDisaggregation"
      CI_STAGE: "CorrectnessTest"
      CI_CATEGORY: "features support matrix"
    agents:
      queue: cpu
    commands:
      - |
        .buildkite/scripts/record_step_result.sh SingleHostPDDisaggregation_CorrectnessTest
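For local debugging, a rough sketch of invoking the same two test cases directly is shown below. It assumes the repository root as the working directory and a host with enough TPU chips for the 4-prefill/4-decode split, and it skips the .buildkite/scripts/run_in_docker.sh wrapper used by the pipeline step above.

# Rough local-run sketch (assumption: run from the repo root). The required
# PREFILL_SLICES / DECODE_SLICES environment variables are set by the tests
# themselves via unittest.mock.patch, so nothing needs to be exported here.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "-s", "-v",
            "tests/e2e/test_local_disagg.py::test_disaggregated_serving",
            "tests/e2e/test_local_disagg.py::test_disaggregated_serving_correctness",
        ]))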

tests/e2e/test_local_disagg.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from dataclasses import asdict
from unittest.mock import patch

import pytest
from vllm import LLM, EngineArgs, SamplingParams

from tpu_inference.core.core_tpu import DisaggEngineCore, DisaggEngineCoreProc


@pytest.fixture
def test_prompts():
    """Simple test prompts for disaggregated serving testing."""
    return [
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
        "How many players are on a standard soccer team on the field at one time?",
        "In Greek mythology, who is the god of the sea?",
        "In what year did the Titanic sink?",
        "In which museum is the Mona Lisa displayed?",
        "Mount Everest is located in which mountain range?",
        "What ancient empire was ruled by Julius Caesar?",
        "What are the four fundamental forces of nature?",
        'What does "CPU" stand for?',
        'What does "HTML" stand for?',
        "What is the capital of Australia?",
        "What is the chemical symbol for gold?",
        "What is the currency of Switzerland?",
        "What is the distance from the Earth to the Sun called?",
        "What is the freezing point of water in Celsius?",
        "What is the hardest known natural substance on Earth?",
        "What is the largest planet in our solar system?",
        "What is the longest river in the world?",
        "What is the main function of the kidneys in the human body?",
        "What is the main ingredient in guacamole?",
        "What is the most spoken language in the world by number of native speakers?",
        "What is the process by which plants use sunlight to create food?",
        "Which country is known as the Land of the Rising Sun?",
        "Who developed the theory of general relativity?",
        'Who directed the original "Star Wars" trilogy?',
        "Who is credited with inventing the telephone?",
        "Who painted the ceiling of the Sistine Chapel?",
        "Who was the first female Prime Minister of the United Kingdom?",
        "Who was the first person to walk on the moon?",
        "Who wrote the American Declaration of Independence?",
        'Who wrote the novel "Pride and Prejudice"?',
    ]


@pytest.fixture
def sampling_params():
    """Standard sampling parameters for testing."""
    return SamplingParams(
        temperature=0.0,
        max_tokens=32,
        ignore_eos=True,
        logprobs=1,
    )


def test_disaggregated_serving(test_prompts, sampling_params):
    """
    Test disaggregated serving end-to-end.

    Equivalent to:
    PREFILL_SLICES=4 DECODE_SLICES=4 python examples/offline_inference.py \
        --model=meta-llama/Meta-Llama-3.1-8B-Instruct --task=generate \
        --max_model_len=2048 --tensor_parallel_size 4
    """
    # Use 4 slices for prefill and 4 for decode. Slice specs are usually
    # written as a TPU topology such as "2x2", but disagg_utils._parse_slices
    # also accepts a plain chip count like "4" (1-D), so the simpler form is
    # used here.

    # Mock the environment variables for this test
    with patch.dict(
            os.environ, {
                "PREFILL_SLICES": "4",
                "DECODE_SLICES": "4",
                "SKIP_JAX_PRECOMPILE": "1",
                "VLLM_XLA_CHECK_RECOMPILATION": "0"
            }):
        # Patch the EngineCore classes to use the Disagg versions
        with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
                patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):

            model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"

            engine_args = EngineArgs(
                model=model_name,
                max_model_len=2048,
                tensor_parallel_size=4,
                gpu_memory_utilization=0.90,
                enforce_eager=False,
            )

            llm = LLM(**asdict(engine_args))

            try:
                outputs = llm.generate(test_prompts, sampling_params)

                # Verify outputs
                assert len(outputs) == len(test_prompts)
                for output in outputs:
                    assert len(output.outputs) > 0
                    assert len(output.outputs[0].text.strip()) > 0
                    print(f"Prompt: {output.prompt!r}")
                    print(f"Generated: {output.outputs[0].text!r}")

            finally:
                # No explicit cleanup needed; the LLM destructor usually handles it.
                pass


def _run_inference(model_name: str,
                   test_prompts: list,
                   sampling_params: SamplingParams,
                   tensor_parallel_size: int = 1,
                   is_disagg: bool = False,
                   prefill_slices: str = "4",
                   decode_slices: str = "4") -> list:
    """Helper function to run inference with the specified configuration."""

    # Define the inner execution logic
    def run_inner():
        engine_args = EngineArgs(
            model=model_name,
            max_model_len=2048,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=0.90,
            enforce_eager=False,
        )

        llm = LLM(**asdict(engine_args))
        try:
            return llm.generate(test_prompts, sampling_params)
        finally:
            # Release the engine between runs. No explicit sleep is needed when
            # mocked, but a short wait is good practice on real hardware.
            del llm

    if is_disagg:
        # Mock the environment variables and patch the engine classes for disagg
        with patch.dict(
                os.environ, {
                    "PREFILL_SLICES": prefill_slices,
                    "DECODE_SLICES": decode_slices,
                    "SKIP_JAX_PRECOMPILE": "1",
                    "VLLM_XLA_CHECK_RECOMPILATION": "0"
                }):
            with patch("vllm.v1.engine.core.EngineCore", DisaggEngineCore), \
                    patch("vllm.v1.engine.core.EngineCoreProc", DisaggEngineCoreProc):
                return run_inner()
    else:
        # Run standard (non-disaggregated) inference. Keep the baseline as close
        # to the default configuration as possible, but match the JAX-related
        # settings of the disaggregated run for a fair comparison.
        with patch.dict(os.environ, {
                "SKIP_JAX_PRECOMPILE": "1",
                "VLLM_XLA_CHECK_RECOMPILATION": "0"
        }):
            return run_inner()


def test_disaggregated_serving_correctness(test_prompts, sampling_params):
    """
    Test that disaggregated serving produces results consistent with a baseline.
    """
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # Use a smaller subset of prompts for correctness testing
    small_prompts = test_prompts[:20]
    sampling_params.max_tokens = 16

    # Run the baseline (standard execution). tensor_parallel_size=4 matches the
    # resources used by the disaggregated run, so both configurations get the
    # same amount of compute.
    print("Running Baseline Inference...")
    baseline_outputs = _run_inference(model_name=model_name,
                                      test_prompts=small_prompts,
                                      sampling_params=sampling_params,
                                      tensor_parallel_size=4,
                                      is_disagg=False)

    # Run disaggregated inference
    print("Running Disaggregated Inference...")
    disagg_outputs = _run_inference(model_name=model_name,
                                    test_prompts=small_prompts,
                                    sampling_params=sampling_params,
                                    tensor_parallel_size=4,
                                    is_disagg=True,
                                    prefill_slices="4",
                                    decode_slices="4")

    # Compare outputs
    assert len(baseline_outputs) == len(disagg_outputs)

    text_matches = 0
    text_mismatches = 0
    token_mismatches = 0

    for i, (baseline,
            disagg) in enumerate(zip(baseline_outputs, disagg_outputs)):
        baseline_text = baseline.outputs[0].text.strip()
        disagg_text = disagg.outputs[0].text.strip()

        # Check text output
        if baseline_text == disagg_text:
            text_matches += 1
        else:
            text_mismatches += 1
            print(f"Text mismatch found in prompt {i}:")
            print(f" Baseline: {baseline_text}")
            print(f" Disagg: {disagg_text}")

        # Check log probabilities (tokens) if available
        baseline_logprobs = baseline.outputs[0].logprobs
        disagg_logprobs = disagg.outputs[0].logprobs

        if baseline_logprobs is not None and disagg_logprobs is not None:
            assert len(baseline_logprobs) == len(disagg_logprobs), \
                f"Logprobs length mismatch: {len(baseline_logprobs)} vs {len(disagg_logprobs)}"

            for token_idx, (base_lp, disagg_lp) in enumerate(
                    zip(baseline_logprobs, disagg_logprobs)):
                if base_lp and disagg_lp:
                    # Compare the top token IDs
                    base_top_token = list(base_lp.keys())[0]
                    disagg_top_token = list(disagg_lp.keys())[0]

                    if base_top_token != disagg_top_token:
                        token_mismatches += 1
                        print(
                            f"Token mismatch in prompt {i}, token {token_idx}:"
                        )
                        print(f" Baseline: {base_top_token}")
                        print(f" Disagg: {disagg_top_token}")

    print("✓ Correctness test results:")
    print(f" Text: {text_matches} matches, {text_mismatches} mismatches")
    print(f" Token mismatches in logprobs: {token_mismatches}")
    assert text_mismatches <= 5, f"Found {text_mismatches} text mismatches"
    assert token_mismatches <= 40, f"Found {token_mismatches} token mismatches"
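The comments in test_disaggregated_serving note that slice specs are normally written as a TPU topology such as "2x2" but that disagg_utils._parse_slices also accepts a plain chip count such as "4". The actual parser is not part of this diff; the snippet below is only an illustrative sketch, using a hypothetical parse_slice_spec helper, of how either form reduces to a chip count.

# Illustrative sketch only -- not the actual disagg_utils._parse_slices.
from math import prod


def parse_slice_spec(spec: str) -> int:
    """Hypothetical helper: number of chips described by '4' or '2x2'."""
    return prod(int(dim) for dim in spec.lower().split("x"))


assert parse_slice_spec("4") == 4
assert parse_slice_spec("2x2") == 4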

tpu_inference/utils.py

Lines changed: 2 additions & 2 deletions
@@ -275,11 +275,11 @@ def device_array(mesh: Mesh, *args, sharding=None, **kwargs) -> jax.Array:
 
 def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]:
     """
-    A wrapper function of vllm.utils.get_hash_fn_by_name to support builtin
+    A wrapper function of vllm.utils.hashing.get_hash_fn_by_name to support builtin
     """
     if hash_fn_name == "builtin":
         return hash
-    return utils.get_hash_fn_by_name(hash_fn_name)
+    return utils.hashing.get_hash_fn_by_name(hash_fn_name)
 
 
 def quantize_kv(key: jax.Array, value: jax.Array,
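For context on the tpu_inference/utils.py change above, a minimal usage sketch of the updated wrapper follows: "builtin" maps to Python's built-in hash(), while any other name is now delegated to vllm.utils.hashing.get_hash_fn_by_name (which names, e.g. "sha256", are available depends on the installed vllm version and is an assumption here).

# Minimal usage sketch for the updated wrapper.
from tpu_inference.utils import get_hash_fn_by_name

hash_fn = get_hash_fn_by_name("builtin")  # Python's built-in hash()
print(hash_fn(("block", 0, 1)))           # e.g. hash a cache-key-like tuple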
