diff --git a/tests/models/common/test_model_loader.py b/tests/models/common/test_model_loader.py
index 2339cefea..3b8d865ef 100644
--- a/tests/models/common/test_model_loader.py
+++ b/tests/models/common/test_model_loader.py
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,
 
         mock_get_flax.assert_not_called()
         mock_get_vllm.assert_not_called()
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
+                                                 mock_get_vllm, vllm_config,
+                                                 rng, mesh):
+        """
+        Tests that 'auto' resolves to 'flax_nnx' for standard architectures
+        (not in _VLLM_PREFERRED_ARCHITECTURES).
+        """
+        # vllm_config uses Qwen3, which is NOT in _VLLM_PREFERRED_ARCHITECTURES
+        mock_get_flax.return_value = "flax_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
+        mock_get_vllm.assert_not_called()
+        assert result == "flax_model_sentinel"
+
+    @patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
+    @patch("tpu_inference.models.common.model_loader.get_vllm_model")
+    @patch("tpu_inference.models.common.model_loader.get_flax_model")
+    def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
+            self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
+        """
+        Tests that 'auto' resolves to 'vllm' for architectures in
+        _VLLM_PREFERRED_ARCHITECTURES (e.g., GptOssForCausalLM).
+        """
+        # Mock the architecture to be GptOssForCausalLM
+        vllm_config.model_config.hf_config.architectures = [
+            "GptOssForCausalLM"
+        ]
+        mock_get_vllm.return_value = "vllm_model_sentinel"
+
+        result = model_loader.get_model(vllm_config, rng, mesh)
+
+        mock_get_flax.assert_not_called()
+        mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
+        assert result == "vllm_model_sentinel"
diff --git a/tests/test_envs.py b/tests/test_envs.py
index e97d36fc5..7149160fb 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -256,7 +256,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):
 
 def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
     monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
-    assert envs.MODEL_IMPL_TYPE == "flax_nnx"
+    assert envs.MODEL_IMPL_TYPE == "auto"
 
 
 def test_cache_preserves_values_across_env_changes(
diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py
index b7d3bad5c..f41e0798b 100644
--- a/tpu_inference/envs.py
+++ b/tpu_inference/envs.py
@@ -16,7 +16,7 @@
     DECODE_SLICES: str = ""
     SKIP_JAX_PRECOMPILE: bool = False
    VLLM_XLA_CHECK_RECOMPILATION: bool = False
-    MODEL_IMPL_TYPE: str = "flax_nnx"
+    MODEL_IMPL_TYPE: str = "auto"
     NEW_MODEL_DESIGN: bool = False
     PHASED_PROFILING_DIR: str = ""
     PYTHON_TRACER_LEVEL: int = 1
@@ -128,8 +128,8 @@ def _get_bool_env() -> bool:
     env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
     # Model implementation type (e.g., "flax_nnx")
     "MODEL_IMPL_TYPE":
-    env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
-                     ["vllm", "flax_nnx", "jetpack"]),
+    env_with_choices("MODEL_IMPL_TYPE", "auto",
+                     ["auto", "vllm", "flax_nnx", "jetpack"]),
     # Enable new experimental model design
     "NEW_MODEL_DESIGN":
     env_bool("NEW_MODEL_DESIGN", default=False),
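The envs.py hunk above is where the new "auto" option enters the system: both the declared default and the env_with_choices registration move from "flax_nnx" to "auto", and "auto" joins the allowed choices. As a quick illustration of the contract this relies on, here is a minimal sketch; the real env_with_choices lives in tpu_inference/envs.py and may differ (e.g., lazy evaluation, caching), so everything below other than the names, default, and choice list is an assumption:

# Hypothetical stand-in for env_with_choices -- only the names, default,
# and choices mirror the diff; the body is illustrative.
import os

def env_with_choices(name, default, choices):
    def _get():
        value = os.environ.get(name, default)
        if value not in choices:
            raise ValueError(f"{name} must be one of {choices}, got {value!r}")
        return value
    return _get

getter = env_with_choices("MODEL_IMPL_TYPE", "auto",
                          ["auto", "vllm", "flax_nnx", "jetpack"])
assert getter() == "auto"  # an unset variable now falls back to "auto"
os.environ["MODEL_IMPL_TYPE"] = "flax_nnx"
assert getter() == "flax_nnx"  # explicit settings still take precedence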
diff --git a/tpu_inference/models/common/model_loader.py b/tpu_inference/models/common/model_loader.py
index 73a761dea..fb035bde1 100644
--- a/tpu_inference/models/common/model_loader.py
+++ b/tpu_inference/models/common/model_loader.py
@@ -24,6 +24,12 @@
 
 _MODEL_REGISTRY = {}
 
+# Architectures that resolve to the "vllm" implementation type when
+# MODEL_IMPL_TYPE is "auto": for now, these perform better with the vLLM
+# PyTorch backend than with the flax_nnx JAX backend.
+_VLLM_PREFERRED_ARCHITECTURES: frozenset[str] = frozenset(
+    {"GptOssForCausalLM"})
+
 
 class UnsupportedArchitectureError(ValueError):
     """Raised when a model architecture is not supported in the registry."""
@@ -342,24 +348,36 @@ def get_model(
     impl = envs.MODEL_IMPL_TYPE
     logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
 
-    if impl == "flax_nnx":
-        try:
-            # Try to load the flax model first
-            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
-        except UnsupportedArchitectureError as e:
-            # Convert the error message to a string to check its contents
-            error_msg = str(e)
-
-            logger.warning(error_msg)
-
-            # Fall back to the vLLM model and updating the dtype accordingly
-            vllm_config.model_config.dtype = j2t_dtype(
-                vllm_config.model_config.dtype.dtype)
+    if impl == "auto":
+        # Resolve "auto" based on the model architecture
+        architectures = getattr(vllm_config.model_config.hf_config,
+                                "architectures", [])
+        assert len(architectures) == 1, (
+            f"Expected exactly one architecture, got {len(architectures)}: "
+            f"{architectures}")
+        arch = architectures[0]
+        impl = "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"
+        logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")
+
+    match impl:
+        case "flax_nnx":
+            try:
+                # Try to load the flax model first
+                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
+            except UnsupportedArchitectureError as e:
+                # Convert the error message to a string to check its contents
+                error_msg = str(e)
+
+                logger.warning(error_msg)
+
+                # Fall back to the vLLM model, updating the dtype accordingly
+                vllm_config.model_config.dtype = j2t_dtype(
+                    vllm_config.model_config.dtype.dtype)
+                return get_vllm_model(vllm_config, rng, mesh)
+        case "vllm":
             return get_vllm_model(vllm_config, rng, mesh)
-    elif impl == "vllm":
-        return get_vllm_model(vllm_config, rng, mesh)
-    else:
-        raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
+        case _:
+            raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")
 
 
 def _validate_model_interface(model: Any) -> None:
diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py
index f100b4fd0..b2fd73ae6 100644
--- a/tpu_inference/runner/tpu_runner.py
+++ b/tpu_inference/runner/tpu_runner.py
@@ -1,6 +1,5 @@
 import copy
 import functools
-import os
 import random
 from contextlib import nullcontext
 from dataclasses import dataclass
@@ -1719,7 +1718,7 @@ def _sync_weights(
                                  shard=shard)
 
     def get_intermediate_tensor_spec(self, num_tokens: int):
-        impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
+        impl = envs.MODEL_IMPL_TYPE
        jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
         num_padded_tokens = runner_utils.get_padded_token_len(
             self.num_tokens_paddings, num_tokens)
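The model_loader.py hunk is the heart of the change: "auto" is resolved once, before dispatch, to "vllm" for architectures in _VLLM_PREFERRED_ARCHITECTURES and to "flax_nnx" for everything else, so the match statement (including the flax-to-vLLM fallback on UnsupportedArchitectureError) behaves exactly as the old if/elif chain did for explicit settings. A standalone sketch of the resolution rule follows; resolve_impl is a hypothetical helper named here for illustration, since the diff inlines this logic directly in get_model():

# Hypothetical helper for illustration; in the diff this logic is inlined
# in get_model() rather than factored out.
_VLLM_PREFERRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})

def resolve_impl(impl, architectures):
    if impl != "auto":
        return impl  # explicit MODEL_IMPL_TYPE values bypass resolution
    assert len(architectures) == 1, (
        f"Expected exactly one architecture, got {len(architectures)}: "
        f"{architectures}")
    arch = architectures[0]
    return "vllm" if arch in _VLLM_PREFERRED_ARCHITECTURES else "flax_nnx"

assert resolve_impl("auto", ["GptOssForCausalLM"]) == "vllm"
assert resolve_impl("auto", ["Qwen3ForCausalLM"]) == "flax_nnx"
assert resolve_impl("flax_nnx", ["GptOssForCausalLM"]) == "flax_nnx"

Resolving before the match, rather than inside it, keeps the fallback semantics untouched for an explicit MODEL_IMPL_TYPE=flax_nnx, which is exactly what the two new tests in test_model_loader.py pin down.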