
Commit c5a05c1

Refactor scheduler and engine.py (#4163)
* refactor sequence states
* fix pd, better property
* skip decoding warmup
* rename
* add more profile logs
* add config builder
* add engine_loop and input_maker
* fix close
* fix ut
* update end session
* fix profiler
* typo fix
1 parent a11f736 commit c5a05c1

File tree

22 files changed: +1609 -1233 lines changed
Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
```python
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os

from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig
from lmdeploy.pytorch.config import (BackendConfig, CacheConfig, DistConfig, MiscConfig, SchedulerConfig,
                                     SpecDecodeConfig)
from lmdeploy.utils import get_logger, get_max_batch_size, get_model


class ConfigBuilder:

    @staticmethod
    def update_engine_config(engine_config: PytorchEngineConfig):
        """Update pytorch engine config."""
        logger = get_logger('lmdeploy')

        # make sure the engine config exists
        if engine_config is None:
            engine_config = PytorchEngineConfig()
        else:
            engine_config = copy.deepcopy(engine_config)

        if engine_config.max_batch_size is None:
            engine_config.max_batch_size = get_max_batch_size(engine_config.device_type)

        if engine_config.dllm_block_length is not None:
            max_prefill_token_num = engine_config.max_prefill_token_num
            max_batch_size = engine_config.max_batch_size
            if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num:
                engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length
                logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} '
                               f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size '
                               f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).')
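            # Illustrative arithmetic (numbers not from this commit): with
            # max_prefill_token_num=8192 and dllm_block_length=64, a requested
            # max_batch_size of 256 would need 256 * 64 = 16384 prefill tokens,
            # so the clamp above lowers max_batch_size to 8192 // 64 = 128.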

        if engine_config.dp != 1:
            if engine_config.tp == 1 and engine_config.ep == 1:
                logger.warning('Data parallelism is enabled but tensor parallelism and '
                               'expert parallelism are not enabled. Setting dp=1.')
                engine_config.dp = 1
                engine_config.dp_rank = 0

        return engine_config

    @staticmethod
    def build_scheduler_config(engine_config: PytorchEngineConfig):
        """Build scheduler config."""
        scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size,
                                           max_session_len=engine_config.session_len,
                                           prefill_interval=engine_config.prefill_interval)
        return scheduler_config

    @staticmethod
    def build_cache_config(engine_config: PytorchEngineConfig):
        """Build cache config."""
        cache_config = CacheConfig(
            max_batches=engine_config.max_batch_size,
            block_size=engine_config.block_size,
            num_cpu_blocks=engine_config.num_cpu_blocks,
            num_gpu_blocks=engine_config.num_gpu_blocks,
            cache_max_entry_count=engine_config.cache_max_entry_count,
            max_prefill_token_num=engine_config.max_prefill_token_num,
            enable_prefix_caching=engine_config.enable_prefix_caching,
            quant_policy=engine_config.quant_policy,
            device_type=engine_config.device_type,
            migration_backend=engine_config.migration_backend,
            role=engine_config.role,
            # reserve 1 block for dummy input and padding
            num_reserved_gpu_blocks=1)
        return cache_config

    @staticmethod
    def build_backend_config(engine_config: PytorchEngineConfig):
        """Build backend config."""
        backend_config = BackendConfig(
            eager_mode=engine_config.eager_mode,
            device_type=engine_config.device_type,
        )
        return backend_config

    @staticmethod
    def build_dist_config(engine_config: PytorchEngineConfig):
        """Build dist config."""
        dist_config = DistConfig.from_engine_config(engine_config=engine_config)
        return dist_config

    @staticmethod
    def build_misc_config(engine_config: PytorchEngineConfig):
        """Build misc config."""
        misc_config = MiscConfig.from_engine_config(engine_config)
        return misc_config

    @staticmethod
    def build_specdecode_config(target_model, speculative_config: SpeculativeConfig,
                                engine_config: PytorchEngineConfig, cache_config: CacheConfig):
        """Build spec decode config."""
        specdecode_config = None
        if speculative_config is not None:
            draft_model = speculative_config.model
            if draft_model and not os.path.exists(speculative_config.model):
                draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision)

            specdecode_config = SpecDecodeConfig.from_config(
                method=speculative_config.method,
                num_speculative_tokens=speculative_config.num_speculative_tokens,
                model=draft_model,
                target_model=target_model,
                target_cache_cfg=cache_config,
                dtype=engine_config.dtype,
            )
        return specdecode_config
```
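
A minimal sketch of how these builders might be wired together during engine setup; the call order, the target model name, and the surrounding engine internals are assumptions for illustration, not taken from this commit:

```python
# Hypothetical wiring of ConfigBuilder; order and usage are illustrative
# assumptions, not part of this diff.
from lmdeploy.messages import PytorchEngineConfig

engine_config = ConfigBuilder.update_engine_config(PytorchEngineConfig())
scheduler_config = ConfigBuilder.build_scheduler_config(engine_config)
cache_config = ConfigBuilder.build_cache_config(engine_config)
backend_config = ConfigBuilder.build_backend_config(engine_config)
dist_config = ConfigBuilder.build_dist_config(engine_config)
misc_config = ConfigBuilder.build_misc_config(engine_config)
# speculative decoding is optional; speculative_config=None yields None
specdecode_config = ConfigBuilder.build_specdecode_config(
    target_model='internlm/internlm2-chat-7b',  # placeholder model id
    speculative_config=None,
    engine_config=engine_config,
    cache_config=cache_config)
```

Since every builder is a `@staticmethod` taking the (already normalized) `engine_config`, the class carries no state; it simply groups the config-construction logic that was previously spread across the engine.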
