-
Notifications
You must be signed in to change notification settings - Fork 55
Add default 'auto' MODEL_IMPL_TYPE that resolves based on architecture #1255
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,6 +333,12 @@ def get_vllm_model( | |
| return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model | ||
|
|
||
|
|
||
| # Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto". | ||
| # These architectures are listed here because they have better performance with the | ||
| # vLLM PyTorch backend compared to the flax_nnx JAX backend for now. | ||
| _VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"}) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In general, constants of this kind should be placed at the start of the file. Please move it. |
||
|
|
||
|
|
||
| def get_model( | ||
| vllm_config: VllmConfig, | ||
| rng: jax.Array, | ||
|
|
@@ -342,24 +348,37 @@ def get_model( | |
| impl = envs.MODEL_IMPL_TYPE | ||
| logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}") | ||
|
|
||
| if impl == "flax_nnx": | ||
| try: | ||
| # Try to load the flax model first | ||
| return get_flax_model(vllm_config, rng, mesh, is_draft_model) | ||
| except UnsupportedArchitectureError as e: | ||
| # Convert the error message to a string to check its contents | ||
| error_msg = str(e) | ||
|
|
||
| logger.warning(error_msg) | ||
|
|
||
| # Fall back to the vLLM model and updating the dtype accordingly | ||
| vllm_config.model_config.dtype = j2t_dtype( | ||
| vllm_config.model_config.dtype.dtype) | ||
| if impl == "auto": | ||
| # Resolve "auto" based on architecture | ||
| architectures = getattr(vllm_config.model_config.hf_config, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dumb question: is there a case where there are multiple "architectures" for a single model? |
||
| "architectures", []) | ||
| for arch in architectures: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to the above comment: can we just do an assert to check if |
||
| if arch in _VLLM_REQUIRED_ARCHITECTURES: | ||
| impl = "vllm" | ||
| break | ||
| else: | ||
| impl = "flax_nnx" | ||
| logger.info(f"Resolved MODEL_IMPL_TYPE 'auto' to '{impl}'") | ||
|
|
||
| match impl: | ||
| case "flax_nnx": | ||
| try: | ||
| # Try to load the flax model first | ||
| return get_flax_model(vllm_config, rng, mesh, is_draft_model) | ||
| except UnsupportedArchitectureError as e: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably a nit question: in C's switch statements, if we don't put |
||
| # Convert the error message to a string to check its contents | ||
| error_msg = str(e) | ||
|
|
||
| logger.warning(error_msg) | ||
|
|
||
| # Fall back to the vLLM model and updating the dtype accordingly | ||
| vllm_config.model_config.dtype = j2t_dtype( | ||
| vllm_config.model_config.dtype.dtype) | ||
| return get_vllm_model(vllm_config, rng, mesh) | ||
| case "vllm": | ||
| return get_vllm_model(vllm_config, rng, mesh) | ||
| elif impl == "vllm": | ||
| return get_vllm_model(vllm_config, rng, mesh) | ||
| else: | ||
| raise NotImplementedError("Unsupported MODEL_IMPL_TYPE") | ||
| case _: | ||
| raise NotImplementedError(f"Unsupported MODEL_IMPL_TYPE: {impl}") | ||
|
|
||
|
|
||
| def _validate_model_interface(model: Any) -> None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"require" might be too strong a word; replace it with "prefer".