diff --git a/.buildkite/parallelism/SP.yml b/.buildkite/parallelism/SP.yml index d86fd4cf0..5474147e1 100644 --- a/.buildkite/parallelism/SP.yml +++ b/.buildkite/parallelism/SP.yml @@ -8,7 +8,8 @@ steps: queue: tpu_v6e_queue commands: - | - buildkite-agent meta-data set "SP_CorrectnessTest" "to be added" + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_sequence_parallelism.py - label: "Record correctness test result for SP" key: "record_SP_CorrectnessTest" depends_on: "SP_CorrectnessTest" diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py new file mode 100644 index 000000000..e9f922ff3 --- /dev/null +++ b/tests/e2e/test_sequence_parallelism.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import logging +import os +import time +from dataclasses import asdict + +import pytest +from vllm import LLM, EngineArgs, SamplingParams +from vllm.config import CompilationConfig + + +@pytest.fixture(autouse=True) +def setup_new_model_design(): + os.environ['MODEL_IMPL_TYPE'] = 'vllm' + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + + +@pytest.fixture +def test_prompts(): + """Simple test prompts for data parallelism testing.""" + return [ + # having a long prompt to trigger a edge case. + "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for Mortal Men doomed to die, One for the Dark Lord on his dark throne In the Land of Mordor where the Shadows lie. One Ring to rule them all, One Ring to find them, One Ring to bring them all and in the darkness bind them In the Land of Mordor where the Shadows lie.", + "Hello, my name is", + "The capital of France is", + "The colors of the rainbow are", + "The future of AI is", + "The president of the United States is", + "How many players are on a standard soccer team?", + "In Greek mythology, who is the god of the sea?", + "What is the capital of Australia?", + "What is the largest planet in our solar system?", + "Who developed the theory of general relativity?", + ] + + +@pytest.fixture +def sampling_params(): + """Standard sampling parameters for testing.""" + return SamplingParams( + temperature=0.0, + max_tokens=32, + ignore_eos=True, + logprobs=1, + ) + + +def _run_inference_with_config(model_name: str, + test_prompts: list, + sampling_params: SamplingParams, + tensor_parallel_size: int = 1, + additional_config: dict = {}, + kv_cache_dtype: str = "auto", + enable_prefix_caching: bool = False, + async_scheduling: bool = False) -> list: + """Helper function to run inference with specified configuration.""" + + # Create LLM args using parser-based approach similar to offline_inference.py + compilation_config = CompilationConfig(pass_config={ + "enable_sequence_parallelism": True, + }, ) + engine_args = EngineArgs( + model=model_name, + max_model_len=128, + compilation_config=compilation_config, + tensor_parallel_size=tensor_parallel_size, + gpu_memory_utilization=0.98, + max_num_batched_tokens=128, + max_num_seqs=16, + enable_prefix_caching=enable_prefix_caching, + additional_config=additional_config, + kv_cache_dtype=kv_cache_dtype, + async_scheduling=async_scheduling, + ) + + engine_args_dict = asdict(engine_args) + llm = LLM(**engine_args_dict) + + try: + outputs = llm.generate(test_prompts, sampling_params) + return outputs + finally: + del llm + # Wait for TPUs to be released + time.sleep(5) + + +@pytest.fixture() +def temporary_enable_log_propagate(): + import logging + + logger = logging.getLogger("vllm") + logger.propagate = True + yield + logger.propagate = False + + +@pytest.fixture() +def caplog_vllm(temporary_enable_log_propagate, caplog): + # To capture vllm log, we should enable propagate=True temporarily + # because caplog depends on logs propagated to the root logger. + yield caplog + + +def test_model_sequence_parallelism( + caplog_vllm: pytest.LogCaptureFixture, + test_prompts: list, + sampling_params: SamplingParams, +): + # Use Llama 1B for this test + test_model = "Qwen/Qwen2.5-32B" + caplog_vllm.clear() + + # Set the logging level for the test + with caplog_vllm.at_level( + logging.INFO, + logger="vllm.tpu_inference.layers.vllm.quantization.common"): + + # Test with data parallelism enabled + outputs = _run_inference_with_config( + model_name=test_model, + test_prompts=test_prompts, + sampling_params=sampling_params, + tensor_parallel_size=8, + async_scheduling=True, + ) + + # Verify we got outputs for all prompts + assert len(outputs) == len(test_prompts) + + # Verify each output has generated text + for output in outputs: + assert len(output.outputs) > 0 + assert len(output.outputs[0].text.strip()) > 0 + print(f"Output: {output.outputs[0].text.strip()}") + + # caplog.text contains all the captured log output + print(f'xw32 {caplog_vllm.records[0].getMessage()}') + print( + f"✓ Model sequence parallelism test passed with {len(outputs)} outputs" + ) diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py index 2b36a795e..20459ff77 100644 --- a/tpu_inference/layers/vllm/quantization/common.py +++ b/tpu_inference/layers/vllm/quantization/common.py @@ -1,7 +1,6 @@ import torchax from jax.sharding import Mesh, PartitionSpec from vllm.config import VllmConfig -from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig # yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -13,6 +12,7 @@ from tpu_inference.layers.vllm.linear_common import \ get_model_matmul_fusion_assignment +from tpu_inference.logger import init_logger from tpu_inference.utils import TPU_SECOND_LAST_MINOR # yapf: enable @@ -73,8 +73,14 @@ def get_input_sharding(self, x: torchax.tensor.Tensor): token_num = x.shape[0] # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR: + logger.debug( + f"SP is enabled, returning non-None input sharding {self.input_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}." + ) return self.input_sharding else: + logger.debug( + f"SP is enabled, returning input sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}." + ) return None return self.input_sharding @@ -83,8 +89,14 @@ def get_output_sharding(self, x: torchax.tensor.Tensor): token_num = x.shape[0] # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR: + logger.debug( + f"SP is enabled, returning non-None output sharding {self.output_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}." + ) return self.output_sharding else: + logger.debug( + f"SP is enabled, return output sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}." + ) return None return self.output_sharding