Add a SP e2e test. #1209
base: main
Changes from 4 commits
New test file:

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import time
from dataclasses import asdict

import pytest
from vllm import LLM, EngineArgs, SamplingParams
from vllm.config import CompilationConfig


@pytest.fixture(autouse=True)
def setup_new_model_design():
    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
    # os.environ['SKIP_JAX_PRECOMPILE'] = '1'


@pytest.fixture
def test_prompts():
    """Simple test prompts for sequence parallelism testing."""
    return [
        # Include a long prompt to trigger an edge case.
        "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for Mortal Men doomed to die, One for the Dark Lord on his dark throne In the Land of Mordor where the Shadows lie. One Ring to rule them all, One Ring to find them, One Ring to bring them all and in the darkness bind them In the Land of Mordor where the Shadows lie.",
        "Hello, my name is",
        "The capital of France is",
        "The colors of the rainbow are",
        "The future of AI is",
        "The president of the United States is",
        "How many players are on a standard soccer team?",
        "In Greek mythology, who is the god of the sea?",
        "What is the capital of Australia?",
        "What is the largest planet in our solar system?",
        "Who developed the theory of general relativity?",
    ]


@pytest.fixture
def sampling_params():
    """Standard sampling parameters for testing."""
    return SamplingParams(
        temperature=0.0,
        max_tokens=32,
        ignore_eos=True,
        logprobs=1,
    )


def _run_inference_with_config(model_name: str,
                               test_prompts: list,
                               sampling_params: SamplingParams,
                               tensor_parallel_size: int = 1,
                               additional_config: dict = {},
                               kv_cache_dtype: str = "auto",
                               enable_prefix_caching: bool = False,
                               async_scheduling: bool = False) -> list:
    """Helper function to run inference with the specified configuration."""

    # Build the engine args directly, similar to offline_inference.py,
    # with the sequence-parallelism compilation pass enabled.
    compilation_config = CompilationConfig(
        pass_config={"enable_sequence_parallelism": True},
    )
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=128,
        compilation_config=compilation_config,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.98,
        max_num_batched_tokens=128,
        max_num_seqs=16,
        enable_prefix_caching=enable_prefix_caching,
        additional_config=additional_config,
        kv_cache_dtype=kv_cache_dtype,
        async_scheduling=async_scheduling,
    )

    engine_args_dict = asdict(engine_args)
    llm = LLM(**engine_args_dict)

    try:
        outputs = llm.generate(test_prompts, sampling_params)
        return outputs
    finally:
        del llm
        # Wait for TPUs to be released.
        time.sleep(5)


def test_model_sequence_parallelism(
    test_prompts: list,
    sampling_params: SamplingParams,
):
    # Use Qwen2.5-32B for this test.
    test_model = "Qwen/Qwen2.5-32B"

    # Run with sequence parallelism enabled.
    outputs = _run_inference_with_config(
        model_name=test_model,
        test_prompts=test_prompts,
        sampling_params=sampling_params,
        tensor_parallel_size=8,
        async_scheduling=True,
    )

    # Verify we got outputs for all prompts.
    assert len(outputs) == len(test_prompts)

    # Verify each output has generated text.
    for output in outputs:
        assert len(output.outputs) > 0
        assert len(output.outputs[0].text.strip()) > 0
        print(f"Output: {output.outputs[0].text.strip()}")

    print(
        f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
    )
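For reviewers who want to reproduce the configuration outside pytest, here is a minimal sketch of the pattern the helper above wires up: enable the sequence-parallelism pass through `CompilationConfig.pass_config`, build `EngineArgs`, and splat the dataclass fields into `LLM`. It uses only the vLLM APIs already imported by the test; the model name, prompt, and sizes simply mirror the test and are illustrative, not requirements of this PR.

```python
# Minimal standalone sketch of the SP-enabled configuration used by the test.
# Same vLLM APIs as in the test above; values are illustrative placeholders.
from dataclasses import asdict

from vllm import LLM, EngineArgs, SamplingParams
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    # This pass is what actually turns on sequence parallelism.
    pass_config={"enable_sequence_parallelism": True},
)
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-32B",       # same model the test uses
    max_model_len=128,
    max_num_batched_tokens=128,
    tensor_parallel_size=8,
    compilation_config=compilation_config,
)
# EngineArgs is a dataclass, so its fields can be passed to LLM as kwargs,
# exactly as the test helper does.
llm = LLM(**asdict(engine_args))
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```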
Changes to the sharding helpers (`get_input_sharding` / `get_output_sharding`), adding logging for the sharding decision:

@@ -73,8 +73,14 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
        token_num = x.shape[0]
        # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
        if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
            logger.info(
                f"SP is enabled and token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return input sharding {self.input_sharding}."
            )
            return self.input_sharding
        else:
            logger.info(
                "SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return input sharding None."
            )
            return None

@@ -83,8 +89,14 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
        token_num = x.shape[0]
        # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
        if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
            logger.info(
                f"SP is enabled and token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding {self.output_sharding}."
            )
            return self.output_sharding
        else:
            logger.info(
                "SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return output sharding None."
            )
            return None
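The new log lines surface the decision both methods make: apply the SP sharding only when each slice of the token dimension holds at least `TPU_SECOND_LAST_MINOR` tokens, otherwise fall back to no sharding. A self-contained sketch of that rule follows; the helper name is hypothetical and the threshold value of 8 is an assumed example, since the real constant comes from the module this diff touches.

```python
# Illustrative sketch of the gating rule described by the new log messages.
# TPU_SECOND_LAST_MINOR = 8 is an assumed example value, not taken from the PR.
TPU_SECOND_LAST_MINOR = 8


def use_sp_sharding(token_num: int, model_axis_size: int) -> bool:
    """Return True if the per-shard token count is large enough to shard."""
    return token_num // model_axis_size >= TPU_SECOND_LAST_MINOR


# With 128 batched tokens on an 8-way model axis, each shard gets 16 tokens,
# which clears the threshold, so the SP input/output sharding is applied.
assert use_sp_sharding(token_num=128, model_axis_size=8) is True

# A small decode step with only 8 tokens leaves 1 token per shard, which is
# below the threshold, so the methods return None and skip sharding.
assert use_sp_sharding(token_num=8, model_axis_size=8) is False
```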