From d77baf9f258f23bed7863caa27472acde9fd16c8 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Mon, 1 Dec 2025 23:36:32 +0000
Subject: [PATCH 1/9] Added an SP e2e test

Signed-off-by: Xiongfei Wei
---
 tests/e2e/test_sequence_parallelism.py | 108 +++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 tests/e2e/test_sequence_parallelism.py

diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
new file mode 100644
index 000000000..70ff36fa2
--- /dev/null
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+import time
+from dataclasses import asdict
+
+import pytest
+from vllm import LLM, EngineArgs, SamplingParams
+
+
+@pytest.fixture(autouse=True)
+def setup_new_model_design():
+    os.environ['MODEL_IMPL_TYPE'] = 'vllm'
+
+
+@pytest.fixture
+def test_prompts():
+    """Simple test prompts for data parallelism testing."""
+    return [
+        "Hello, my name is",
+        # "The capital of France is",
+        # "The colors of the rainbow are",
+        # "The future of AI is",
+        # "The president of the United States is",
+        # "How many players are on a standard soccer team?",
+        # "In Greek mythology, who is the god of the sea?",
+        # "What is the capital of Australia?",
+        # "What is the largest planet in our solar system?",
+        # "Who developed the theory of general relativity?",
+    ]
+
+
+@pytest.fixture
+def sampling_params():
+    """Standard sampling parameters for testing."""
+    return SamplingParams(
+        temperature=0.0,
+        max_tokens=32,
+        ignore_eos=True,
+        logprobs=1,
+    )
+
+
+def _run_inference_with_config(model_name: str,
+                               test_prompts: list,
+                               sampling_params: SamplingParams,
+                               tensor_parallel_size: int = 1,
+                               additional_config: dict = {},
+                               kv_cache_dtype: str = "auto",
+                               enable_prefix_caching: bool = False,
+                               async_scheduling: bool = False) -> list:
+    """Helper function to run inference with specified configuration."""
+
+    # Create LLM args using parser-based approach similar to offline_inference.py
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=32,
+        tensor_parallel_size=tensor_parallel_size,
+        gpu_memory_utilization=0.98,
+        max_num_batched_tokens=128,
+        max_num_seqs=16,
+        enable_prefix_caching=enable_prefix_caching,
+        additional_config=additional_config,
+        kv_cache_dtype=kv_cache_dtype,
+        async_scheduling=async_scheduling,
+    )
+
+    engine_args_dict = asdict(engine_args)
+    llm = LLM(**engine_args_dict)
+
+    try:
+        outputs = llm.generate(test_prompts, sampling_params)
+        return outputs
+    finally:
+        del llm
+        # Wait for TPUs to be released
+        time.sleep(5)
+
+
+def test_model_sequence_parallelism(
+    test_prompts: list,
+    sampling_params: SamplingParams,
+):
+    # Use Llama 1B for this test
+    test_model = "Qwen/Qwen2.5-32B"
+
+    # Test with data parallelism enabled
+    outputs = _run_inference_with_config(
+        model_name=test_model,
+        test_prompts=test_prompts,
+        sampling_params=sampling_params,
+        tensor_parallel_size=8,
+        async_scheduling=True,
+    )
+
+    # Verify we got outputs for all prompts
+    assert len(outputs) == len(test_prompts)
+
+    # Verify each output has generated text
+    for output in outputs:
+        assert len(output.outputs) > 0
+        assert len(output.outputs[0].text.strip()) > 0
+        print(f"Output: {output.outputs[0].text.strip()}")
+
+    print(
+        f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
+    )

From 82ab9ebf189ccf771f788a9af5fd8be06d5ce289 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 2 Dec 2025 00:29:07 +0000
Subject: [PATCH 2/9] I'm able to run the test as pytest -s -vv tests/e2e/test_sequence_parallelism.py

Signed-off-by: Xiongfei Wei
---
 tests/e2e/test_sequence_parallelism.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
index 70ff36fa2..b5de22b66 100644
--- a/tests/e2e/test_sequence_parallelism.py
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -7,6 +7,7 @@
 
 import pytest
 from vllm import LLM, EngineArgs, SamplingParams
+from vllm.config import CompilationConfig
 
 
 @pytest.fixture(autouse=True)
@@ -53,9 +54,13 @@ def _run_inference_with_config(model_name: str,
     """Helper function to run inference with specified configuration."""
 
     # Create LLM args using parser-based approach similar to offline_inference.py
+    compilation_config = CompilationConfig(pass_config={
+        "enable_sequence_parallelism": True,
+    }, )
     engine_args = EngineArgs(
         model=model_name,
         max_model_len=32,
+        compilation_config=compilation_config,
         tensor_parallel_size=tensor_parallel_size,
         gpu_memory_utilization=0.98,
         max_num_batched_tokens=128,

From 8653c0124887b49f573b494c54ba8ed649e15e89 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 2 Dec 2025 00:52:49 +0000
Subject: [PATCH 3/9] Added more test cases.

Signed-off-by: Xiongfei Wei
---
 tests/e2e/test_sequence_parallelism.py           |  6 ++++--
 tpu_inference/layers/vllm/quantization/common.py | 12 ++++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
index b5de22b66..34971cc0a 100644
--- a/tests/e2e/test_sequence_parallelism.py
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -13,13 +13,15 @@
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
     os.environ['MODEL_IMPL_TYPE'] = 'vllm'
+    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
 
 
 @pytest.fixture
 def test_prompts():
     """Simple test prompts for data parallelism testing."""
     return [
-        "Hello, my name is",
+        "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for Mortal Men doomed to die, One for the Dark Lord on his dark throne In the Land of Mordor where the Shadows lie. One Ring to rule them all, One Ring to find them, One Ring to bring them all and in the darkness bind them In the Land of Mordor where the Shadows lie.",
+        # "Hello, my name is",
         # "The capital of France is",
         # "The colors of the rainbow are",
         # "The future of AI is",
@@ -59,7 +61,7 @@ def _run_inference_with_config(model_name: str,
     }, )
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=32,
+        max_model_len=128,
         compilation_config=compilation_config,
         tensor_parallel_size=tensor_parallel_size,
         gpu_memory_utilization=0.98,
diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py
index 2b36a795e..bc248603f 100644
--- a/tpu_inference/layers/vllm/quantization/common.py
+++ b/tpu_inference/layers/vllm/quantization/common.py
@@ -73,8 +73,14 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                logger.info(
+                    f"SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return input sharding {self.input_sharding}."
+                )
                 return self.input_sharding
             else:
+                logger.info(
+                    "SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return input sharding None."
+                )
                 return None
         return self.input_sharding
 
@@ -83,8 +89,14 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+                logger.info(
+                    f"SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding {self.output_sharding}."
+                )
                 return self.output_sharding
             else:
+                logger.info(
+                    "SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding None."
+                )
                 return None
         return self.output_sharding
 

From f257cb79f7e562cb9079a1fa364afa0ecc7066b6 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 2 Dec 2025 01:07:03 +0000
Subject: [PATCH 4/9] Add SP e2e test to the CI.

Signed-off-by: Xiongfei Wei
---
 .buildkite/pipeline_jax.yml            | 13 +++++++++++++
 tests/e2e/test_sequence_parallelism.py | 23 ++++++++++++-----------
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml
index 42032735b..fbcf4933e 100644
--- a/.buildkite/pipeline_jax.yml
+++ b/.buildkite/pipeline_jax.yml
@@ -262,6 +262,19 @@ steps:
       - |
         .buildkite/scripts/run_in_docker.sh \
           bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
+
+  - label: "E2E sequence parallelism test"
+    key: test_17
+    soft_fail: true
+    env:
+      VLLM_LOG_LEVEL: "INFO"
+    agents:
+      queue: tpu_v6e_8_queue
+    commands:
+    - |
+      .buildkite/scripts/run_in_docker.sh \
+        bash -c 'pytest -s -vv tests/e2e/test_sequence_parallelism.py'
+
   # -----------------------------------------------------------------
   # NOTIFICATION STEP
   # -----------------------------------------------------------------
diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
index 34971cc0a..07c6412da 100644
--- a/tests/e2e/test_sequence_parallelism.py
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -13,24 +13,25 @@
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
     os.environ['MODEL_IMPL_TYPE'] = 'vllm'
-    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    # os.environ['SKIP_JAX_PRECOMPILE'] = '1'
 
 
 @pytest.fixture
 def test_prompts():
     """Simple test prompts for data parallelism testing."""
     return [
+        # having a long prompt to trigger an edge case.
         "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for Mortal Men doomed to die, One for the Dark Lord on his dark throne In the Land of Mordor where the Shadows lie. One Ring to rule them all, One Ring to find them, One Ring to bring them all and in the darkness bind them In the Land of Mordor where the Shadows lie.",
-        # "Hello, my name is",
-        # "The capital of France is",
-        # "The colors of the rainbow are",
-        # "The future of AI is",
-        # "The president of the United States is",
-        # "How many players are on a standard soccer team?",
-        # "In Greek mythology, who is the god of the sea?",
-        # "What is the capital of Australia?",
-        # "What is the largest planet in our solar system?",
-        # "Who developed the theory of general relativity?",
+        "Hello, my name is",
+        "The capital of France is",
+        "The colors of the rainbow are",
+        "The future of AI is",
+        "The president of the United States is",
+        "How many players are on a standard soccer team?",
+        "In Greek mythology, who is the god of the sea?",
+        "What is the capital of Australia?",
+        "What is the largest planet in our solar system?",
+        "Who developed the theory of general relativity?",
     ]
 
 

From e9ba556c9b0a3db1dbc96b85537e8e096a586713 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 2 Dec 2025 17:22:31 +0000
Subject: [PATCH 5/9] Improve the error message

Signed-off-by: Xiongfei Wei
---
 tpu_inference/layers/vllm/quantization/common.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py
index bc248603f..0324fb6be 100644
--- a/tpu_inference/layers/vllm/quantization/common.py
+++ b/tpu_inference/layers/vllm/quantization/common.py
@@ -1,7 +1,6 @@
 import torchax
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
-from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -13,6 +12,7 @@
 
 from tpu_inference.layers.vllm.linear_common import \
     get_model_matmul_fusion_assignment
+from tpu_inference.logger import init_logger
 from tpu_inference.utils import TPU_SECOND_LAST_MINOR
 
 # yapf: enable
@@ -74,12 +74,12 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
                 logger.info(
-                    f"SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return input sharding {self.input_sharding}."
+                    f"SP is enabled, returning non-None input sharding {self.input_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return self.input_sharding
             else:
                 logger.info(
-                    "SP is enabled, but token_num // self.mesh.shape[\"model\"] < TPU_SECOND_LAST_MINOR, return input sharding None."
+                    f"SP is enabled, returning input sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return None
         return self.input_sharding
@@ -90,12 +90,12 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
                 logger.info(
-                    f"SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding {self.output_sharding}."
+                    f"SP is enabled, returning non-None output sharding {self.output_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return self.output_sharding
             else:
                 logger.info(
-                    "SP is enabled, but token_num // self.mesh.shape[\"model\"] >= TPU_SECOND_LAST_MINOR, return output sharding None."
+                    f"SP is enabled, return output sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return None
         return self.output_sharding

From 7afe8fef35b3617beaf8eca08a60ca43c03895b5 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Tue, 2 Dec 2025 22:32:43 +0000
Subject: [PATCH 6/9] Fix up

Signed-off-by: Xiongfei Wei
---
 .buildkite/parallelism/SP.yml |  3 ++-
 .buildkite/pipeline_jax.yml   | 13 -------------
 2 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/.buildkite/parallelism/SP.yml b/.buildkite/parallelism/SP.yml
index d86fd4cf0..5474147e1 100644
--- a/.buildkite/parallelism/SP.yml
+++ b/.buildkite/parallelism/SP.yml
@@ -8,7 +8,8 @@ steps:
       queue: tpu_v6e_queue
     commands:
     - |
-      buildkite-agent meta-data set "SP_CorrectnessTest" "to be added"
+      .buildkite/scripts/run_in_docker.sh \
+        python3 -m pytest -s -v /workspace/tpu_inference/tests/e2e/test_sequence_parallelism.py
   - label: "Record correctness test result for SP"
     key: "record_SP_CorrectnessTest"
     depends_on: "SP_CorrectnessTest"
diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml
index fbcf4933e..42032735b 100644
--- a/.buildkite/pipeline_jax.yml
+++ b/.buildkite/pipeline_jax.yml
@@ -262,19 +262,6 @@ steps:
       - |
         .buildkite/scripts/run_in_docker.sh \
           bash -c 'python3 -m pytest -s -v -x /workspace/tpu_inference/tests/lora/test_layers.py'
-
-  - label: "E2E sequence parallelism test"
-    key: test_17
-    soft_fail: true
-    env:
-      VLLM_LOG_LEVEL: "INFO"
-    agents:
-      queue: tpu_v6e_8_queue
-    commands:
-    - |
-      .buildkite/scripts/run_in_docker.sh \
-        bash -c 'pytest -s -vv tests/e2e/test_sequence_parallelism.py'
-
   # -----------------------------------------------------------------
   # NOTIFICATION STEP
   # -----------------------------------------------------------------

From 59899b0a41d71bf8f14d2bb34146a6393470a0f8 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Wed, 3 Dec 2025 00:01:45 +0000
Subject: [PATCH 7/9] Try to make the test capture the log

Signed-off-by: Xiongfei Wei
---
 tests/e2e/test_sequence_parallelism.py | 67 +++++++++++++++++---------
 1 file changed, 45 insertions(+), 22 deletions(-)

diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
index 07c6412da..21779c78e 100644
--- a/tests/e2e/test_sequence_parallelism.py
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import io
+import logging
 import os
 import time
 from dataclasses import asdict
@@ -13,7 +15,7 @@
 @pytest.fixture(autouse=True)
 def setup_new_model_design():
     os.environ['MODEL_IMPL_TYPE'] = 'vllm'
-    # os.environ['SKIP_JAX_PRECOMPILE'] = '1'
+    os.environ['SKIP_JAX_PRECOMPILE'] = '1'
 
 
 @pytest.fixture
@@ -87,30 +89,51 @@ def _run_inference_with_config(model_name: str,
         time.sleep(5)
 
 
 def test_model_sequence_parallelism(
+    caplog: pytest.LogCaptureFixture,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
-    # Use Llama 1B for this test
-    test_model = "Qwen/Qwen2.5-32B"
-
-    # Test with data parallelism enabled
-    outputs = _run_inference_with_config(
-        model_name=test_model,
-        test_prompts=test_prompts,
-        sampling_params=sampling_params,
-        tensor_parallel_size=8,
-        async_scheduling=True,
-    )
+    logger_name = "vllm.tpu_inference.layers.vllm.quantization.common"
+    logger = logging.getLogger(logger_name)
+    original_level = logger.level
+    original_propagate = logger.propagate
 
-    # Verify we got outputs for all prompts
-    assert len(outputs) == len(test_prompts)
+    # Create an in-memory stream to capture log output
+    log_capture_string = io.StringIO()
+    # Create a handler that writes to our in-memory stream
+    capture_handler = logging.StreamHandler(log_capture_string)
 
-    # Verify each output has generated text
-    for output in outputs:
-        assert len(output.outputs) > 0
-        assert len(output.outputs[0].text.strip()) > 0
-        print(f"Output: {output.outputs[0].text.strip()}")
+    test_model = "Qwen/Qwen2.5-32B"
 
-    print(
-        f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
-    )
+    try:
+        logger.setLevel(logging.DEBUG)
+        logger.propagate = False
+        logger.addHandler(capture_handler)
+
+        outputs = _run_inference_with_config(
+            model_name=test_model,
+            test_prompts=test_prompts,
+            sampling_params=sampling_params,
+            tensor_parallel_size=8,
+            async_scheduling=True,
+        )
+
+        # Verify we got outputs for all prompts
+        assert len(outputs) == len(test_prompts)
+
+        # Verify each output has generated text
+        for output in outputs:
+            assert len(output.outputs) > 0
+            assert len(output.outputs[0].text.strip()) > 0
+            print(f"Output: {output.outputs[0].text.strip()}")
+
+        log_contents = log_capture_string.getvalue()
+        print(f'xw32 {log_contents=}')
+        print(
+            f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
+        )
+    finally:
+        # --- IMPORTANT: Clean up and restore the logger's original state ---
+        logger.setLevel(original_level)
+        logger.propagate = original_propagate
+        logger.removeHandler(capture_handler)

From dab3ec6b6543bab5612705e1b9f81ea032faf36b Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Wed, 3 Dec 2025 01:50:43 +0000
Subject: [PATCH 8/9] Still couldn't capture the log; consider reverting this and the last commit.

Signed-off-by: Xiongfei Wei
---
 tests/e2e/test_sequence_parallelism.py | 50 ++++++++++++++------------
 1 file changed, 27 insertions(+), 23 deletions(-)

diff --git a/tests/e2e/test_sequence_parallelism.py b/tests/e2e/test_sequence_parallelism.py
index 21779c78e..e9f922ff3 100644
--- a/tests/e2e/test_sequence_parallelism.py
+++ b/tests/e2e/test_sequence_parallelism.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import io
 import logging
 import os
 import time
@@ -88,28 +87,38 @@ def _run_inference_with_config(model_name: str,
         time.sleep(5)
 
 
+@pytest.fixture()
+def temporary_enable_log_propagate():
+    import logging
+
+    logger = logging.getLogger("vllm")
+    logger.propagate = True
+    yield
+    logger.propagate = False
+
+
+@pytest.fixture()
+def caplog_vllm(temporary_enable_log_propagate, caplog):
+    # To capture vllm log, we should enable propagate=True temporarily
+    # because caplog depends on logs propagated to the root logger.
+    yield caplog
+
+
 def test_model_sequence_parallelism(
-    caplog: pytest.LogCaptureFixture,
+    caplog_vllm: pytest.LogCaptureFixture,
     test_prompts: list,
     sampling_params: SamplingParams,
 ):
-    logger_name = "vllm.tpu_inference.layers.vllm.quantization.common"
-    logger = logging.getLogger(logger_name)
-    original_level = logger.level
-    original_propagate = logger.propagate
-
-    # Create an in-memory stream to capture log output
-    log_capture_string = io.StringIO()
-    # Create a handler that writes to our in-memory stream
-    capture_handler = logging.StreamHandler(log_capture_string)
-
+    # Use Llama 1B for this test
     test_model = "Qwen/Qwen2.5-32B"
+    caplog_vllm.clear()
 
-    try:
-        logger.setLevel(logging.DEBUG)
-        logger.propagate = False
-        logger.addHandler(capture_handler)
+    # Set the logging level for the test
+    with caplog_vllm.at_level(
+            logging.INFO,
+            logger="vllm.tpu_inference.layers.vllm.quantization.common"):
 
+        # Test with data parallelism enabled
         outputs = _run_inference_with_config(
             model_name=test_model,
             test_prompts=test_prompts,
             sampling_params=sampling_params,
             tensor_parallel_size=8,
             async_scheduling=True,
         )
 
         # Verify we got outputs for all prompts
         assert len(outputs) == len(test_prompts)
 
         # Verify each output has generated text
         for output in outputs:
             assert len(output.outputs) > 0
             assert len(output.outputs[0].text.strip()) > 0
             print(f"Output: {output.outputs[0].text.strip()}")
 
-        log_contents = log_capture_string.getvalue()
-        print(f'xw32 {log_contents=}')
+        # caplog.text contains all the captured log output
+        print(f'xw32 {caplog_vllm.records[0].getMessage()}')
         print(
             f"✓ Model sequence parallelism test passed with {len(outputs)} outputs"
         )
-    finally:
-        # --- IMPORTANT: Clean up and restore the logger's original state ---
-        logger.setLevel(original_level)
-        logger.propagate = original_propagate
-        logger.removeHandler(capture_handler)

From decf2123baf90cc208aa783a769d8eebc21e4d45 Mon Sep 17 00:00:00 2001
From: Xiongfei Wei
Date: Wed, 3 Dec 2025 18:49:37 +0000
Subject: [PATCH 9/9] Change logger.info to logger.debug

Signed-off-by: Xiongfei Wei
---
 tpu_inference/layers/vllm/quantization/common.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py
index 0324fb6be..20459ff77 100644
--- a/tpu_inference/layers/vllm/quantization/common.py
+++ b/tpu_inference/layers/vllm/quantization/common.py
@@ -73,12 +73,12 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
-                logger.info(
+                logger.debug(
                     f"SP is enabled, returning non-None input sharding {self.input_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return self.input_sharding
             else:
-                logger.info(
+                logger.debug(
                     f"SP is enabled, returning input sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return None
@@ -89,12 +89,12 @@ def get_output_sharding(self, x: torchax.tensor.Tensor):
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
-                logger.info(
+                logger.debug(
                     f"SP is enabled, returning non-None output sharding {self.output_sharding}. token_num // self.mesh.shape[model] >= TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return self.output_sharding
             else:
-                logger.info(
+                logger.debug(
                     f"SP is enabled, return output sharding None. token_num // self.mesh.shape[model] < TPU_SECOND_LAST_MINOR. {token_num=}, {self.mesh.shape['model']=}, {TPU_SECOND_LAST_MINOR=}."
                 )
                 return None