13 changes: 13 additions & 0 deletions .buildkite/pipeline_jax.yml
@@ -262,6 +262,19 @@ steps:
- |
.buildkite/scripts/run_in_docker.sh \
bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'

- label: "E2E sequence parallelism test"
key: test_17
soft_fail: true
env:
VLLM_LOG_LEVEL: "INFO"
agents:
queue: tpu_v6e_8_queue
commands:
- |
.buildkite/scripts/run_in_docker.sh \
bash -c 'pytest -s -vv tests/e2e/test_sequence_parallelism.py'

# -----------------------------------------------------------------
# NOTIFICATION STEP
# -----------------------------------------------------------------
116 changes: 116 additions & 0 deletions tests/e2e/test_sequence_parallelism.py
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time
from dataclasses import asdict

import pytest
from vllm import LLM, EngineArgs, SamplingParams
from vllm.config import CompilationConfig


@pytest.fixture(autouse=True)
def setup_new_model_design():
os.environ['MODEL_IMPL_TYPE'] = 'vllm'
# os.environ['SKIP_JAX_PRECOMPILE'] = '1'


@pytest.fixture
def test_prompts():
"""Simple test prompts for data parallelism testing."""
return [
# A long prompt to trigger the edge case where token_num // tensor_parallel_size >= TPU_SECOND_LAST_MINOR.
"Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for Mortal Men doomed to die, One for the Dark Lord on his dark throne In the Land of Mordor where the Shadows lie. One Ring to rule them all, One Ring to find them, One Ring to bring them all and in the darkness bind them In the Land of Mordor where the Shadows lie.",
"Hello, my name is",
"The capital of France is",
"The colors of the rainbow are",
"The future of AI is",
"The president of the United States is",
"How many players are on a standard soccer team?",
"In Greek mythology, who is the god of the sea?",
"What is the capital of Australia?",
"What is the largest planet in our solar system?",
"Who developed the theory of general relativity?",
]


@pytest.fixture
def sampling_params():
"""Standard sampling parameters for testing."""
return SamplingParams(
temperature=0.0,
max_tokens=32,
ignore_eos=True,
logprobs=1,
)


def _run_inference_with_config(model_name: str,
test_prompts: list,
sampling_params: SamplingParams,
tensor_parallel_size: int = 1,
additional_config: dict = {},
kv_cache_dtype: str = "auto",
enable_prefix_caching: bool = False,
async_scheduling: bool = False) -> list:
"""Helper function to run inference with specified configuration."""

# Build engine args directly, enabling sequence parallelism via the compilation pass config.
compilation_config = CompilationConfig(pass_config={
"enable_sequence_parallelism": True,
}, )
engine_args = EngineArgs(
model=model_name,
max_model_len=128,
compilation_config=compilation_config,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=0.98,
max_num_batched_tokens=128,
max_num_seqs=16,
enable_prefix_caching=enable_prefix_caching,
additional_config=additional_config,
kv_cache_dtype=kv_cache_dtype,
async_scheduling=async_scheduling,
)

engine_args_dict = asdict(engine_args)
llm = LLM(**engine_args_dict)

try:
outputs = llm.generate(test_prompts, sampling_params)
return outputs
finally:
del llm
# Wait for TPUs to be released
time.sleep(5)


def test_model_sequence_parallelism(
test_prompts: list,
sampling_params: SamplingParams,
):
# Use Qwen2.5-32B for this test
test_model = "Qwen/Qwen2.5-32B"

# Test with sequence parallelism enabled
outputs = _run_inference_with_config(
model_name=test_model,
test_prompts=test_prompts,
sampling_params=sampling_params,
tensor_parallel_size=8,
async_scheduling=True,
)

# Verify we got outputs for all prompts
assert len(outputs) == len(test_prompts)

# Verify each output has generated text
for output in outputs:
assert len(output.outputs) > 0
assert len(output.outputs[0].text.strip()) > 0
print(f"Output: {output.outputs[0].text.strip()}")

print(
f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
)
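
As a side note on the test configuration, here is a quick check of the edge case the long prompt is meant to hit. This is a sketch, not part of the PR, and it assumes TPU_SECOND_LAST_MINOR is 8, taken from the "token_num // 8 >= 8" arithmetic in the review discussion further down rather than from the constant's definition.

# Hypothetical standalone check; TPU_SECOND_LAST_MINOR == 8 is an assumption here.
TPU_SECOND_LAST_MINOR = 8
tensor_parallel_size = 8

def takes_sharded_path(token_num: int) -> bool:
    # Mirrors the condition in get_input_sharding/get_output_sharding below.
    return token_num // tensor_parallel_size >= TPU_SECOND_LAST_MINOR

assert takes_sharded_path(128)     # a full 128-token prefill batch keeps the sharding
assert not takes_sharded_path(32)  # a short batch falls back to unsharded input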
12 changes: 12 additions & 0 deletions tpu_inference/layers/vllm/quantization/common.py
@@ -73,8 +73,14 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
logger.info(
f"SP is enabled, and token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return input sharding {self.input_sharding}."
)
return self.input_sharding
else:
logger.info(
"SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return input sharding None."
)
return None
return self.input_sharding

Collaborator:
In our e2e test, should we check the log to see that SP is actually enabled? Another way is to check the final optimized graph, but that's more difficult.

Collaborator Author (@vanbasten23, Dec 2, 2025):
I checked manually that this line (76) is executed, but I think you meant whether we can check it in the test. I'm not sure how to do that in the test; let me know if you have ideas.

That said, I added a long prompt (line 24, "Three Rings...") to ensure token_num // 8 >= 8 is triggered. The precompilation phase also uses a very large num_tokens, so this case is triggered there as well.

Collaborator:
Can we write the log to a file and check the file's contents afterwards?

Collaborator:
Could it be logger.debug instead of logger.info?

Collaborator Author:
No luck so far; somehow I couldn't capture the logger. I agree this may be the easiest way to test, but if someone writes a similar logging string, the test may not work as intended. Other parallelism modes have the same issue: it's hard to examine in a test how each layer is sharded.

I've also been thinking about whether there is a better way to test. Since SP's main benefit is reduced memory usage, we could check that memory usage actually drops with SP enabled, but there is no JAX API that lets me check memory usage.

How about we merge this PR so that it verifies the test runs to completion with "enable_sequence_parallelism=True" and tensor_parallel_size=8, since that is the intended way to enable SP? Then, when we do the integration, we can improve the test. Wdyt?
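
Following up on the log-capture idea in the thread above, here is a minimal sketch of how the e2e test could assert that this code path ran. It reuses _run_inference_with_config from the new test file; the logger name is a guess (assuming the module calls logging.getLogger(__name__)), and it only works if the sharding code runs in the pytest process. If the engine workers run in separate processes, caplog will not see these records, and a file handler configured inside the worker (the log-to-file suggestion above) would be needed instead.

import logging

# Hypothetical logger name; adjust to the actual module path if it differs.
SP_LOGGER_NAME = "tpu_inference.layers.vllm.quantization.common"

def test_sp_sharding_log_is_emitted(test_prompts, sampling_params, caplog):
    with caplog.at_level(logging.INFO, logger=SP_LOGGER_NAME):
        outputs = _run_inference_with_config(
            model_name="Qwen/Qwen2.5-32B",
            test_prompts=test_prompts,
            sampling_params=sampling_params,
            tensor_parallel_size=8,
        )
    assert len(outputs) == len(test_prompts)
    # Match a short, stable substring rather than the full message, since the
    # exact wording is brittle (as noted in the discussion above).
    assert any("SP is enabled" in record.getMessage() for record in caplog.records)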

@@ -83,8 +89,14 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
logger.info(
f"SP is enabled, and token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding {self.output_sharding}."
)
return self.output_sharding
else:
logger.info(
"SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return output sharding None."
)
return None
return self.output_sharding
