40 changes: 40 additions & 0 deletions tests/models/common/test_model_loader.py
@@ -381,3 +381,43 @@ def test_get_model_not_implemented(self, mock_get_flax, mock_get_vllm,

mock_get_flax.assert_not_called()
mock_get_vllm.assert_not_called()

@patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
@patch("tpu_inference.models.common.model_loader.get_vllm_model")
@patch("tpu_inference.models.common.model_loader.get_flax_model")
def test_get_model_auto_resolves_to_flax_nnx(self, mock_get_flax,
mock_get_vllm, vllm_config,
rng, mesh):
"""
Tests that 'auto' resolves to 'flax_nnx' for standard architectures
(not in _VLLM_REQUIRED_ARCHITECTURES).
"""
# vllm_config uses Qwen3 which is NOT in _VLLM_REQUIRED_ARCHITECTURES
mock_get_flax.return_value = "flax_model_sentinel"

result = model_loader.get_model(vllm_config, rng, mesh)

mock_get_flax.assert_called_once_with(vllm_config, rng, mesh, False)
mock_get_vllm.assert_not_called()
assert result == "flax_model_sentinel"

@patch.dict(os.environ, {"MODEL_IMPL_TYPE": "auto"}, clear=True)
@patch("tpu_inference.models.common.model_loader.get_vllm_model")
@patch("tpu_inference.models.common.model_loader.get_flax_model")
def test_get_model_auto_resolves_to_vllm_for_gpt_oss(
self, mock_get_flax, mock_get_vllm, vllm_config, rng, mesh):
"""
Tests that 'auto' resolves to 'vllm' for architectures in
_VLLM_REQUIRED_ARCHITECTURES (e.g., GptOssForCausalLM).
"""
# Mock the architecture to be GptOssForCausalLM
vllm_config.model_config.hf_config.architectures = [
"GptOssForCausalLM"
]
mock_get_vllm.return_value = "vllm_model_sentinel"

result = model_loader.get_model(vllm_config, rng, mesh)

mock_get_flax.assert_not_called()
mock_get_vllm.assert_called_once_with(vllm_config, rng, mesh)
assert result == "vllm_model_sentinel"
2 changes: 1 addition & 1 deletion tests/test_envs.py
@@ -250,7 +250,7 @@ def test_disaggregated_serving_env_vars(monkeypatch: pytest.MonkeyPatch):

def test_model_impl_type_default(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delenv("MODEL_IMPL_TYPE", raising=False)
assert envs.MODEL_IMPL_TYPE == "flax_nnx"
assert envs.MODEL_IMPL_TYPE == "auto"


def test_cache_preserves_values_across_env_changes(
6 changes: 3 additions & 3 deletions tpu_inference/envs.py
@@ -16,7 +16,7 @@
DECODE_SLICES: str = ""
SKIP_JAX_PRECOMPILE: bool = False
VLLM_XLA_CHECK_RECOMPILATION: bool = False
MODEL_IMPL_TYPE: str = "flax_nnx"
MODEL_IMPL_TYPE: str = "auto"
NEW_MODEL_DESIGN: bool = False
PHASED_PROFILING_DIR: str = ""
PYTHON_TRACER_LEVEL: int = 1
@@ -127,8 +127,8 @@ def _get_bool_env() -> bool:
env_bool("VLLM_XLA_CHECK_RECOMPILATION", default=False),
# Model implementation type (e.g., "flax_nnx")
"MODEL_IMPL_TYPE":
env_with_choices("MODEL_IMPL_TYPE", "flax_nnx",
["vllm", "flax_nnx", "jetpack"]),
env_with_choices("MODEL_IMPL_TYPE", "auto",
["auto", "vllm", "flax_nnx", "jetpack"]),
# Enable new experimental model design
"NEW_MODEL_DESIGN":
env_bool("NEW_MODEL_DESIGN", default=False),
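For context on what the new default means at runtime, here is a minimal standalone sketch of the resolution behavior this registration implies. The env_with_choices helper below is a hypothetical stand-in written only for illustration; it is not the project's actual implementation.

```python
import os
from typing import Callable, List


def env_with_choices(name: str, default: str,
                     choices: List[str]) -> Callable[[], str]:
    """Hypothetical stand-in for the project's helper: returns a lazy getter
    that yields the default when the variable is unset and rejects values
    outside the allowed choices."""
    def getter() -> str:
        value = os.getenv(name, default)
        if value not in choices:
            raise ValueError(f"{name}={value!r} is not one of {choices}")
        return value
    return getter


get_impl = env_with_choices("MODEL_IMPL_TYPE", "auto",
                            ["auto", "vllm", "flax_nnx", "jetpack"])
print(get_impl())  # "auto" unless MODEL_IMPL_TYPE is set to another allowed value
```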
53 changes: 36 additions & 17 deletions tpu_inference/models/common/model_loader.py
@@ -333,6 +333,12 @@ def get_vllm_model(
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model


# Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
Collaborator comment: "require" might be too strong a word; replace it with "prefer".

# These architectures are listed here because they have better performance with the
# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})
Collaborator comment: In general, constants like this should be placed at the start of the file. Please move it.



def get_model(
vllm_config: VllmConfig,
rng: jax.Array,
@@ -342,24 +348,37 @@ def get_model(
impl = envs.MODEL_IMPL_TYPE
logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")

if impl == "flax_nnx":
try:
# Try to load the flax model first
return get_flax_model(vllm_config, rng, mesh, is_draft_model)
except UnsupportedArchitectureError as e:
# Convert the error message to a string to check its contents
error_msg = str(e)

logger.warning(error_msg)

# Fall back to the vLLM model and update the dtype accordingly
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)
if impl == "auto":
# Resolve "auto" based on architecture
architectures = getattr(vllm_config.model_config.hf_config,
"architectures", [])
Collaborator comment: Dumb question: are there cases where a single model has multiple "architectures"?

for arch in architectures:
Collaborator comment: Similar to the comment above: can we just assert that len(architectures) == 1 and do a simple hash-map lookup instead of iterating with a for loop?
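An illustrative sketch of the set-lookup alternative this comment suggests; it reuses names from the diff but is standalone and not part of the PR:

```python
# Standalone sketch of the reviewer's suggestion: resolve "auto" with a set
# intersection instead of an explicit for loop (illustrative only).
_VLLM_REQUIRED_ARCHITECTURES = frozenset({"GptOssForCausalLM"})


def resolve_auto_impl(architectures: list[str]) -> str:
    # Any overlap with the preferred-architecture set selects the vLLM
    # backend; otherwise fall back to flax_nnx.
    if _VLLM_REQUIRED_ARCHITECTURES.intersection(architectures):
        return "vllm"
    return "flax_nnx"


assert resolve_auto_impl(["GptOssForCausalLM"]) == "vllm"
assert resolve_auto_impl(["Qwen3ForCausalLM"]) == "flax_nnx"
```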

if arch in _VLLM_REQUIRED_ARCHITECTURES:
impl = "vllm"
break
else:
impl = "flax_nnx"
logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'")

match impl:
case "flax_nnx":
try:
# Try to load the flax model first
return get_flax_model(vllm_config, rng, mesh, is_draft_model)
except UnsupportedArchitectureError as e:
Collaborator comment: Probably a nit question: in C, a switch case without break; falls through to the next case. Is that not the case for Python's match/case? I.e., if UnsupportedArchitectureError is thrown, do we skip the break; and automatically invoke the next case (case "vllm")?
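A small standalone example of the relevant Python semantics, separate from this codebase: match/case has no C-style fall-through, and an exception raised inside a case body propagates out of the match rather than moving to the next case, which is why the explicit try/except handles the fallback.

```python
# Standalone illustration of Python match/case semantics (not project code).
class UnsupportedArchitectureError(Exception):
    pass


def load(impl: str) -> str:
    match impl:
        case "flax_nnx":
            try:
                raise UnsupportedArchitectureError("arch not supported")
            except UnsupportedArchitectureError:
                # Without this handler the exception would propagate out of
                # the whole match statement; it would NOT fall through to
                # the case "vllm" branch below.
                return "vllm fallback"
        case "vllm":
            return "vllm"
        case _:
            raise NotImplementedError(impl)


assert load("flax_nnx") == "vllm fallback"
assert load("vllm") == "vllm"
```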

# Convert the error message to a string to check its contents
error_msg = str(e)

logger.warning(error_msg)

# Fall back to the vLLM model and update the dtype accordingly
vllm_config.model_config.dtype = j2t_dtype(
vllm_config.model_config.dtype.dtype)
return get_vllm_model(vllm_config, rng, mesh)
case "vllm":
return get_vllm_model(vllm_config, rng, mesh)
elif impl == "vllm":
return get_vllm_model(vllm_config, rng, mesh)
else:
raise NotImplementedError("Unsupported MODEL_IMPL_TYPE")
case _:
raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}")


def _validate_model_interface(model: Any) -> None:
3 changes: 1 addition & 2 deletions tpu_inference/runner/tpu_runner.py
@@ -1,6 +1,5 @@
import copy
import functools
import os
import random
from contextlib import nullcontext
from dataclasses import dataclass
@@ -1719,7 +1718,7 @@ def _sync_weights(
shard=shard)

def get_intermediate_tensor_spec(self, num_tokens: int):
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
impl = envs.MODEL_IMPL_TYPE
jax_dtype = t2j_dtype(self.dtype) if impl == "vllm" else self.dtype
num_padded_tokens = runner_utils.get_padded_token_len(
self.num_tokens_paddings, num_tokens)