Skip to content

Commit ea6e583

Browse files
committed
address comments
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent c17d46b commit ea6e583

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

vllm/v1/engine/core.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -607,9 +607,12 @@ def __init__(
607607

608608
self._init_data_parallel(vllm_config)
609609

610-
super().__init__(
611-
vllm_config, executor_class, log_stats, executor_fail_callback
612-
)
610+
from vllm.config import set_current_vllm_config
611+
612+
with set_current_vllm_config(vllm_config):
613+
super().__init__(
614+
vllm_config, executor_class, log_stats, executor_fail_callback
615+
)
613616

614617
# Background Threads and Queues for IO. These enable us to
615618
# overlap ZMQ socket IO with GPU since they release the GIL,

vllm/v1/worker/gpu_worker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,10 @@ def init_device(self):
269269
# to hijack tensor allocation.
270270
def load_model(self) -> None:
271271
eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1"
272-
with self._maybe_get_memory_pool_context(
273-
tag="weights"
274-
) and set_current_vllm_config(self.vllm_config):
272+
with (
273+
self._maybe_get_memory_pool_context(tag="weights"),
274+
set_current_vllm_config(self.vllm_config),
275+
):
275276
self.model_runner.load_model(eep_scale_up=eep_scale_up)
276277

277278
def update_config(self, overrides: dict[str, Any]) -> None:

0 commit comments

Comments
 (0)