Skip to content

Commit ea0cb49

Browse files
DarkLight1337 authored and devpatelio committed
[Bugfix] Use HF config fields as fallback when loading Mistral config (vllm-project#29239)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 4ac5372 commit ea0cb49

File tree

4 files changed

+25
-4
lines changed

4 files changed

+25
-4
lines changed

.buildkite/test-amd.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,7 @@ steps:
754754
torch_nightly: true
755755
source_file_dependencies:
756756
- vllm/model_executor/models/
757+
- vllm/transformers_utils/
757758
- tests/models/test_initialization.py
758759
commands:
759760
# Only when vLLM model source is modified - test initialization of a large

.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,7 @@ steps:
691691
torch_nightly: true
692692
source_file_dependencies:
693693
- vllm/model_executor/models/
694+
- vllm/transformers_utils/
694695
- tests/models/test_initialization.py
695696
commands:
696697
# Only when vLLM model source is modified - test initialization of a large

vllm/transformers_utils/config.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,19 @@ def parse(
204204

205205
from vllm.transformers_utils.configs.mistral import adapt_config_dict
206206

207-
config = adapt_config_dict(config_dict)
207+
# Get missing fields from HF config if available
208+
try:
209+
hf_config_dict, _ = PretrainedConfig.get_config_dict(
210+
model,
211+
revision=revision,
212+
code_revision=code_revision,
213+
token=_get_hf_token(),
214+
**kwargs,
215+
)
216+
except OSError: # Not found
217+
hf_config_dict = {}
218+
219+
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
208220

209221
# Mistral configs may define sliding_window as list[int]. Convert it
210222
# to int and add the layer_types list[str] to make it HF compatible

vllm/transformers_utils/configs/mistral.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@
99
logger = init_logger(__name__)
1010

1111

12-
def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
13-
config_dict.update(kwargs)
12+
def adapt_config_dict(
13+
config_dict: dict[str, Any],
14+
defaults: dict[str, Any],
15+
) -> PretrainedConfig:
1416
config_dict = _remap_general_mistral_args(config_dict)
1517

1618
if bool(config_dict.get("quantization")):
1719
config_dict = _remap_mistral_quantization_args(config_dict)
1820

19-
if bool(config_dict.get("moe")):
21+
if config_dict.get("model_type") == "mamba":
22+
config_dict["architectures"] = ["Mamba2ForCausalLM"]
23+
elif bool(config_dict.get("moe")):
2024
config_dict["architectures"] = ["MixtralForCausalLM"]
2125
else:
2226
config_dict["architectures"] = ["MistralForCausalLM"]
@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
5256
if is_audio:
5357
config_dict = _remap_mistral_audio_args(config_dict)
5458

59+
for k, v in defaults.items():
60+
config_dict.setdefault(k, v)
61+
5562
config = PretrainedConfig.from_dict(config_dict)
5663

5764
logger.debug("Initialized config %s", config)

0 commit comments

Comments (0)