Add default 'auto' MODEL_IMPL_TYPE that resolves based on architecture #1255
base: main
Conversation
@kyuyeunk Please review.
kyuyeunk
left a comment
Wouldn't it be possible to move 'auto' into the match/case as well?
- Add 'auto' as default value for MODEL_IMPL_TYPE env var
- For GptOssForCausalLM, 'auto' resolves to 'vllm' for better performance
- For all other architectures, 'auto' resolves to 'flax_nnx'
- Add _VLLM_REQUIRED_ARCHITECTURES frozenset in model_loader.py
- Use match/case pattern in get_model() for implementation selection
- Add tests for 'auto' resolution behavior

Signed-off-by: Xing Liu <xingliu14@gmail.com>
It is possible to move it into the match-case, but then it would have duplicated code, including get_vllm_model, get_flax_model, and the fallback check. I think resolving first and then using the same code path is cleaner.
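Roughly, the resolve-first shape described above might look like the sketch below. This is an illustration only, not the actual diff: get_flax_model, get_vllm_model, UnsupportedArchitectureError, and the _VLLM_REQUIRED_ARCHITECTURES constant are taken from the snippets in this conversation and assumed to exist, and the env var is read directly with os.getenv for simplicity.

```python
import os

class UnsupportedArchitectureError(Exception):
    """Stub for illustration; the real exception lives in the model loader."""

_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})

def get_model(vllm_config, rng, mesh, is_draft_model=False):
    # Reading the env var directly here purely for illustration.
    impl = os.getenv("MODEL_IMPL_TYPE", "auto").lower()

    if impl == "auto":
        # Resolve "auto" once, up front, based on the model architecture.
        architectures = getattr(vllm_config.model_config.hf_config,
                                "architectures", [])
        if any(arch in _VLLM_REQUIRED_ARCHITECTURES for arch in architectures):
            impl = "vllm"
        else:
            impl = "flax_nnx"

    # After resolution, explicit and auto-resolved values share one dispatch path,
    # so get_flax_model / get_vllm_model and the fallback check appear only once.
    match impl:
        case "flax_nnx":
            try:
                # Try to load the flax model first.
                return get_flax_model(vllm_config, rng, mesh, is_draft_model)
            except UnsupportedArchitectureError:
                # Fall back to the vLLM PyTorch implementation.
                return get_vllm_model(vllm_config, rng, mesh)
        case "vllm":
            return get_vllm_model(vllm_config, rng, mesh)
        case _:
            raise ValueError(f"Unknown MODEL_IMPL_TYPE: {impl}")
```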
return jit_model, compute_logits_fn, combine_hidden_states_fn, None, params, lora_manager, model
# Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
"require" might be too strong word. replace it with "prefer"
# Architectures that require "vllm" implementation type when MODEL_IMPL_TYPE is "auto".
# These architectures are listed here because they have better performance with the
# vLLM PyTorch backend compared to the flax_nnx JAX backend for now.
_VLLM_REQUIRED_ARCHITECTURES: frozenset[str] = frozenset({"GptOssForCausalLM"})
In general, constants like this should be placed at the start of the file. Please move it.
                       vllm_config.model_config.dtype.dtype)
if impl == "auto":
    # Resolve "auto" based on architecture
    architectures = getattr(vllm_config.model_config.hf_config,
Dumb question: are there cases where there are multiple "architectures" for a single model?
| # Resolve "auto" based on architecture | ||
| architectures = getattr(vllm_config.model_config.hf_config, | ||
| "architectures", []) | ||
| for arch in architectures: |
Similar to the above comment: can we just assert that len(architectures) == 1 and do a simple hash-map lookup instead of iterating with a for loop?
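A hypothetical sketch of that suggestion follows; the _AUTO_IMPL_BY_ARCH mapping and resolve_auto_impl helper are illustrative names, not part of this PR.

```python
# Hypothetical alternative raised in review; not part of the actual diff.
_AUTO_IMPL_BY_ARCH: dict[str, str] = {"GptOssForCausalLM": "vllm"}

def resolve_auto_impl(architectures: list[str]) -> str:
    # Assert there is exactly one architecture, then do a single dict lookup
    # instead of looping over the list.
    assert len(architectures) == 1, (
        f"Expected exactly one architecture, got {architectures}")
    return _AUTO_IMPL_BY_ARCH.get(architectures[0], "flax_nnx")
```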
        try:
            # Try to load the flax model first
            return get_flax_model(vllm_config, rng, mesh, is_draft_model)
        except UnsupportedArchitectureError as e:
Probably a nit question: in C's switch statements, if we don't put break;, execution automatically falls through to the next case. Is that not the case for Python's match/case? I.e., if UnsupportedArchitectureError is thrown, do we skip the break; statement and automatically let the next case (which is case "vllm") be invoked?
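For context, a minimal standalone sketch (not from this PR) showing that Python's match/case never falls through; any fallback to "vllm" has to come from the explicit try/except inside the first case, not from reaching the next case.

```python
class UnsupportedArchitectureError(Exception):
    """Stub for illustration."""

def load(impl: str) -> str:
    match impl:
        case "flax_nnx":
            try:
                raise UnsupportedArchitectureError("not supported")
            except UnsupportedArchitectureError:
                # The fallback must be written explicitly; execution never
                # "falls through" into case "vllm" below.
                return "vllm fallback"
        case "vllm":
            return "vllm"
    return "unknown"

print(load("flax_nnx"))  # -> "vllm fallback"
```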
Description
- Add 'auto' as default value for MODEL_IMPL_TYPE env var
- For GptOssForCausalLM, 'auto' resolves to 'vllm' for better performance
- For all other architectures, 'auto' resolves to 'flax_nnx'

Tests
pytest
Checklist
Before submitting this PR, please make sure: