From d96b324fc4ea46783b9efb8345218a278d328364 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 18:27:16 +0800 Subject: [PATCH 01/12] refactor sequence states --- lmdeploy/pytorch/engine/engine.py | 45 ++-- lmdeploy/pytorch/messages.py | 57 +++-- .../paging/eviction_helper/__init__.py | 17 +- .../recompute_eviction_helper.py | 7 +- lmdeploy/pytorch/paging/scheduler.py | 207 +++++------------- .../pytorch/paging/seq_states/__init__.py | 2 + lmdeploy/pytorch/paging/seq_states/states.py | 166 ++++++++++++++ lmdeploy/pytorch/strategies/ar/sequence.py | 4 +- .../pytorch/strategies/ar_spec/sequence.py | 4 +- lmdeploy/pytorch/strategies/dllm/sequence.py | 4 +- tests/pytorch/paging/test_block_manager.py | 92 ++++++-- tests/pytorch/paging/test_block_trie.py | 55 +++-- tests/pytorch/paging/test_scheduler.py | 66 +++--- 13 files changed, 438 insertions(+), 288 deletions(-) create mode 100644 lmdeploy/pytorch/paging/seq_states/__init__.py create mode 100644 lmdeploy/pytorch/paging/seq_states/states.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index b343040873..609f8fa274 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -18,7 +18,7 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch -from lmdeploy.utils import get_logger, get_max_batch_size, get_model, logging_timer +from lmdeploy.utils import get_logger, get_max_batch_size, get_model from ..adapter.adapter import AdapterManager from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig @@ -281,7 +281,7 @@ def do_prefill_dp(self): if self.next_is_prefill: ret = scheduler.has_waiting() else: - ret = not scheduler.has_running() + ret = not scheduler.has_ready() return ret def do_prefill_default(self): @@ -289,7 +289,7 @@ def do_prefill_default(self): scheduler = self.scheduler if not scheduler.has_waiting(): return False - num_running = scheduler.num_running() + num_ready = scheduler.num_ready() num_waiting = scheduler.num_waiting() max_batches = self.scheduler_config.max_batches # prefill if too much waiting @@ -297,7 +297,7 @@ def do_prefill_default(self): if num_waiting >= permitted_waiting: return True # prefill if no enough running - if num_running < max_batches * 0.5: + if num_ready < max_batches * 0.5: return True # decoding return False @@ -328,11 +328,11 @@ async def prefetch_next_inputs(self): if prefill: enable = True else: - num_running = scheduler.num_running() + num_ready = scheduler.num_ready() is_decoding = self.forward_inputs['inputs'].is_decoding running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0 - if num_running > running_threshold: + if num_ready > running_threshold: enable = True if enable: @@ -592,7 +592,7 @@ def _on_end_session(self, reqs: List[Request], **kwargs): if session_id in self.scheduler.sessions: msgs = list(self.scheduler.sessions[session_id].sequences.values()) if len(msgs) > 0 and msgs[0].preserve_cache: - self.scheduler._set_message_status(msgs[0], MessageStatus.TO_BE_MIGRATED) + msgs[0].state.finish() else: self.end_session(session_id) resp_type = ResponseType.SUCCESS @@ -676,9 +676,7 @@ def __update_max_new_tokens(msg): preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) __update_max_new_tokens(msg) - 
scheduler.add_sequence(msg) if migration_request: - self.scheduler._set_message_status(msg, MessageStatus.WAITING_MIGRATION) self.migration_event.set() else: msg = next(iter(sess.sequences.values())) @@ -689,7 +687,7 @@ def __update_max_new_tokens(msg): mode=UpdateTokenMode.INPUTS, ) msg.sampling_param = sampling_param - msg.status = MessageStatus.WAITING + msg.state.activate() __update_max_new_tokens(msg) msg.resp = req.resp @@ -775,7 +773,6 @@ def __has_values(input_multimodals): return vision_embedding_inputs @torch.inference_mode() - @logging_timer('CreateModelInputs', logger) @record_function('CreateModelInputs') def create_model_inputs(self, messages: SeqList, is_prefill: bool): """Create model inputs from messages. @@ -861,7 +858,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, if model_metas is None: model_metas = [None] * len(running) for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.MIGRATION_LOCKED: + if msg.status != MessageStatus.MIGRATION_RUNNING: continue update_token = token @@ -870,7 +867,7 @@ def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, if stop: update_token = _EMPTY_TOKEN msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) - msg.status = MessageStatus.STOPPED + msg.state.finish() @record_function('make_infer_outputs') def _make_infer_outputs( @@ -889,7 +886,7 @@ def _make_infer_outputs( logprobs.indices = logprobs.indices.tolist() seq_length = [seq.num_token_ids for seq in running] - is_run = [seq.status == MessageStatus.LOCKED for seq in running] + is_run = [seq.status == MessageStatus.RUNNING for seq in running] self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) # generate output @@ -966,7 +963,7 @@ def __need_schedule_again(prefill: bool, scheduler_output): if (self.engine_config.role == EngineRole.Prefill): return False # disable decoding if no running reqs. 
- if not self.scheduler.has_running(): + if not self.scheduler.has_ready(): logger.warning('No running sequences for decoding scheduling after prefill scheduling.') return False return True @@ -1107,12 +1104,12 @@ async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionReque async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): """Async loop migration.""" while True: - migration_running = self.scheduler._schedule_migration() - if not migration_running and not self.scheduler.has_migration_waiting(): + migration_ready = self.scheduler._schedule_migration() + if not migration_ready and not self.scheduler.has_migration_waiting(): await self.migration_event.wait() - elif migration_running: + elif migration_ready: self.migration_event.clear() - for msg in migration_running: + for msg in migration_ready: migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] migration_request = msg.migration_request prefill_block_ids = migration_request.remote_block_ids @@ -1137,8 +1134,8 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event # generate output outputs: Dict[int, InferOutput] = dict() - self.scheduler.lock_running_migration(migration_running) - for _, msg in enumerate(migration_running): + self.scheduler.activate_migration_seqs(migration_ready) + for _, msg in enumerate(migration_ready): session_id = msg.session_id msg.resp.type = ResponseType.SUCCESS token_ids = [msg.migration_request.remote_token_id] @@ -1155,7 +1152,7 @@ async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event outputs[session_id] = out self.update_running_migration([msg], np.array([token_ids]), [False], [None]) resp_que.put_nowait(outputs) - self.scheduler.unlock_running_migration(migration_running) + self.scheduler.deactivate_migration_seqs(migration_ready) has_runable_event.set() else: # release coroutine for decoding @@ -1202,7 +1199,7 @@ async def _async_loop_main( is_decoding = forward_inputs['inputs'].is_decoding running = next_running next_running = None - scheduler.lock_running(running) + scheduler.active_seqs(running) for idx in range(num_loops): # pre-forward before get last token @@ -1221,7 +1218,7 @@ async def _async_loop_main( if idx == num_loops // 2: forward_event.clear() - scheduler.unlock_running(running) + scheduler.deactive_seqs(running) has_runable_event.set() @staticmethod diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 2e02fabf04..1fd26ea2cd 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -14,6 +14,8 @@ from .block import LogicalTokenBlocks if TYPE_CHECKING: + from lmdeploy.pytorch.paging.scheduler import Scheduler + from lmdeploy.pytorch.paging.seq_states.states import StateBase from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy logger = get_logger('lmdeploy') @@ -146,21 +148,19 @@ class MessageStatus(enum.Enum): """Status of a sequence.""" WAITING = enum.auto() - RUNNING = enum.auto() + READY = enum.auto() STOPPED = enum.auto() - ENDED = enum.auto() - ABORTED = enum.auto() - LOCKED = enum.auto() + RUNNING = enum.auto() # PD Disaggregation - # WAITING_MIGRATION: state of Unmigrated Requests + # MIGRATION_WAITING: state of Unmigrated Requests # in both prefill and decode engines are tagged by - # RUNNING_MIGRATION: state of Migrating Requests + # MIGRATION_READY: state of Migrating Requests # in decode engine TO_BE_MIGRATED = enum.auto() - WAITING_MIGRATION = enum.auto() - RUNNING_MIGRATION 
= enum.auto() - MIGRATION_LOCKED = enum.auto() + MIGRATION_WAITING = enum.auto() + MIGRATION_READY = enum.auto() + MIGRATION_RUNNING = enum.auto() MIGRATION_DONE = enum.auto() @@ -203,10 +203,9 @@ def num_sequences(self, status: MessageStatus): """Num sequences.""" return len(self.get_sequences(status)) - def add_sequence(self, seq: 'SchedulerSequence'): + def add_sequence(self, seq: 'SchedulerSequence', status: MessageStatus): """Add sequence.""" seq_id = seq.seq_id - status = seq.status status_map = self._status_seq_map[status] self._seq_map[seq_id] = seq status_map[seq_id] = seq @@ -247,12 +246,12 @@ def _to_ndarray(token_ids) -> np.ndarray: class SchedulerSession: """Scheduler session.""" - def __init__(self, session_id: int, seq_manager: SequenceManager) -> None: + def __init__(self, session_id: int, seq_manager: SequenceManager, scheduler: 'Scheduler') -> None: self.session_id = session_id self.seq_meta = seq_manager.seq_meta - self.status: MessageStatus = MessageStatus.RUNNING self.sequences: SeqMap = dict() self.seq_manager = seq_manager + self.scheduler = scheduler def add_sequence(self, token_ids: Tensor, @@ -264,6 +263,8 @@ def add_sequence(self, resp_cache: bool = False, preserve_cache: bool = False) -> 'SchedulerSequence': """Add a new message.""" + from lmdeploy.pytorch.paging.seq_states.states import build_seq_state + if sampling_param is None: sampling_param = SamplingParam() @@ -282,12 +283,22 @@ def add_sequence(self, mode=UpdateTokenMode.INPUTS, ) self.sequences[seq.seq_id] = seq - self.seq_manager.add_sequence(seq) + + # set status + # update seq manager + status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING + seq.set_state(build_seq_state(self.scheduler, seq, status)) + self.seq_manager.add_sequence(seq, status) + + # metrics + seq.record_event(EventType.QUEUED) + return seq def remove_sequence(self, seq: 'SchedulerSequence'): """Remove sequence.""" assert seq.seq_id in self.sequences + seq.state.free() self.sequences.pop(seq.seq_id) self.seq_manager.remove_sequence(seq) @@ -557,7 +568,6 @@ class SchedulerSequence: arrive_time: float = 0.0 output_start_pos: int = 0 meta: Any = None - _status: MessageStatus = field(default=MessageStatus.WAITING, init=False) num_ignored_history: int = 0 model_meta: Dict[str, Any] = None @@ -583,6 +593,7 @@ def __post_init__(self): self._num_images: int = len(self.history_embeddings) self._num_history_cross: int = 0 self._num_cross: int = self.history_multimodals.get_encoder_len(0, self._num_token_ids) + self._state = None @property def block_size(self) -> int: @@ -692,23 +703,21 @@ def num_blocks(self): return len(self.logical_blocks) @property - def seq_manager(self) -> SequenceManager: - """Sequence manager.""" - return self.session.seq_manager + def state(self) -> 'StateBase': + return self._state + + def set_state(self, state: 'StateBase'): + """Set state.""" + self._state = state @property def status(self): - return self._status + return self.state.status @property def return_logits(self): return self.sampling_param.out_logits - @status.setter - def status(self, value: MessageStatus): - self.seq_manager.update_sequence_status(self, value) - self._status = value - def num_all_cross_tokens(self): """Num of all cross tokens.""" return self._num_cross + self._num_history_cross diff --git a/lmdeploy/pytorch/paging/eviction_helper/__init__.py b/lmdeploy/pytorch/paging/eviction_helper/__init__.py index 9d82582f1e..6b5c44ff97 100644 --- a/lmdeploy/pytorch/paging/eviction_helper/__init__.py +++ 
b/lmdeploy/pytorch/paging/eviction_helper/__init__.py @@ -1,4 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .recompute_eviction_helper import RecomputeEvictionHelper +from lmdeploy.utils import get_logger -__all__ = ['RecomputeEvictionHelper'] +logger = get_logger('lmdeploy') + + +def build_eviction_helper(scheduler, eviction_type: str): + """Build eviction helper.""" + if eviction_type == 'copy': + logger.warning('`copy` eviction has been deprecated, ' + 'use `recompute` instead.') + eviction_type = 'recompute' + if eviction_type == 'recompute': + from .recompute_eviction_helper import RecomputeEvictionHelper + return RecomputeEvictionHelper(scheduler) + else: + raise TypeError(f'Unknown eviction type: {eviction_type}') diff --git a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py index be0d09a5f9..bdded115dd 100644 --- a/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py +++ b/lmdeploy/pytorch/paging/eviction_helper/recompute_eviction_helper.py @@ -35,8 +35,7 @@ def _evict_for_seq_default(self, seq: SchedulerSequence, evictable_seqs: List[Sc if evict_seq.num_blocks == 0: continue - block_manager.free(evict_seq) - evict_seq.set_step(0) + evict_seq.state.free() num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: success = True @@ -77,9 +76,7 @@ def _evict_for_ssm(self, seq: SchedulerSequence, evictable_seqs: List[SchedulerS continue # free sequence - block_manager.free(evict_seq) - evict_seq.set_step(0) - state_manager.free(evict_seq) + evict_seq.state.free() has_free_state = True num_req = (num_required_blocks - block_manager.get_num_free_gpu_blocks()) if num_req <= 0: diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 579d335d22..1359c88ae5 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -6,12 +6,13 @@ from typing import Dict, List from lmdeploy.messages import EventType, ScheduleMetrics -from lmdeploy.utils import get_logger, logging_timer +from lmdeploy.utils import get_logger from ..config import CacheConfig, SchedulerConfig from ..messages import MessageStatus, SchedulerSequence, SchedulerSession, SequenceManager, SequenceMeta from .block_manager import build_block_manager from .block_trie import BlockTrie +from .eviction_helper import build_eviction_helper from .state_manager import StateManager logger = get_logger('lmdeploy') @@ -55,7 +56,7 @@ def __init__( self.state_manager = StateManager(self.cache_config.num_state_caches) self.is_ssm = len(self.cache_config.states_shapes) > 0 - self.eviction_helper = self.build_eviction_helper(self.scheduler_config.eviction_type) + self.eviction_helper = build_eviction_helper(self, self.scheduler_config.eviction_type) seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) self.seq_manager = SequenceManager(seq_meta) @@ -67,9 +68,9 @@ def waiting(self): return list(seq_map.values()) @property - def running(self): + def ready(self): """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) + seq_map = self.seq_manager.get_sequences(MessageStatus.READY) return list(seq_map.values()) @property @@ -79,21 +80,15 @@ def hanging(self): return list(seq_map.values()) @property - def locked(self): + def running(self): """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.LOCKED) - return list(seq_map.values()) - - @property - 
def waiting_migration(self): - """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) return list(seq_map.values()) @property - def running_migration(self): + def migration_waiting(self): """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING_MIGRATION) + seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_WAITING) return list(seq_map.values()) @property @@ -102,26 +97,6 @@ def migration_done(self): seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) return list(seq_map.values()) - def build_eviction_helper(self, eviction_type: str): - if eviction_type == 'copy': - logger.warning('`copy` eviction has been deprecated, ' - 'use `recompute` instead.') - eviction_type = 'recompute' - if eviction_type == 'recompute': - from .eviction_helper import RecomputeEvictionHelper - return RecomputeEvictionHelper(self) - else: - raise TypeError(f'Unknown eviction type: {eviction_type}') - - def _set_message_status(self, message: SchedulerSequence, status: MessageStatus): - """Set status of message. - - Args: - message (SchedulerSequence): message to setup status. - status (MessageStatus): New message status. - """ - message.status = status - def add_session(self, session_id: int): """Add new session. @@ -129,32 +104,18 @@ def add_session(self, session_id: int): session_id (int): New session id. """ assert session_id not in self.sessions - session = SchedulerSession(session_id, seq_manager=self.seq_manager) + session = SchedulerSession(session_id, seq_manager=self.seq_manager, scheduler=self) self.sessions[session_id] = session return session - def add_sequence(self, seq: SchedulerSequence): - """Add sequence. - - Args: - seq (SchedulerSequence): New sequence. 
- """ - assert (seq.session_id in self.sessions), f'Unknown session id {seq.session_id}' - - # push message to waiting queue - self._set_message_status(seq, MessageStatus.WAITING) - - seq.record_event(EventType.QUEUED) - - @logging_timer('ScheduleMigration', logger) def _schedule_migration(self): - running_migration: SeqList = [] + migration_ready: SeqList = [] migrating_token_count = 0 def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING_MIGRATION - running_migration.append(seq) + seq.state.activate() + migration_ready.append(seq) nonlocal migrating_token_count migrating_token_count += seq.num_token_ids @@ -169,28 +130,27 @@ def __evict_for_seq(seq: SchedulerSequence, waiting): def _reorder_migrating(): """Reorder waiting.""" - return sorted(self.waiting_migration, key=lambda seq: seq.arrive_time) + return sorted(self.migration_waiting, key=lambda seq: seq.arrive_time) - waiting_migration = _reorder_migrating() + migration_waiting = _reorder_migrating() - max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() - while len(waiting_migration) > 0 and len(running_migration) < max_batches: - seq = waiting_migration.pop(0) - self.block_trie.match(waiting_migration) - if not __evict_for_seq(seq, waiting_migration): + max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() + while len(migration_waiting) > 0 and len(migration_ready) < max_batches: + seq = migration_waiting.pop(0) + self.block_trie.match(migration_waiting) + if not __evict_for_seq(seq, migration_waiting): break # allocate session memory self.block_manager.allocate(seq) _to_running(seq) - return running_migration + return migration_ready - @logging_timer('SchedulePrefilling', logger) def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" - max_batches = self.scheduler_config.max_batches - self.num_running() - self.num_locked() + max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() eviction_helper = self.eviction_helper swap_out_map: Dict[int, int] = dict() swap_in_map: Dict[int, int] = dict() @@ -200,7 +160,7 @@ def _schedule_prefill(self, prealloc_size: int = 0): def _to_running(seq: SchedulerSequence): """To running.""" - seq.status = MessageStatus.RUNNING + seq.state.activate() running.append(seq) nonlocal token_count token_count += seq.num_token_ids @@ -243,11 +203,10 @@ def _reorder_waiting(): return running, swap_in_map, swap_out_map, copy_map - @logging_timer('ScheduleDecoding', logger) def _schedule_decoding(self, prealloc_size: int = 0): """Schedule decoding.""" - running = self.running + running = self.ready assert len(running) != 0 eviction_helper = self.eviction_helper @@ -272,27 +231,18 @@ def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): # 1. running for seq in running: - # token + n - num_required_blocks = self.block_manager.num_required_blocks(seq, prealloc_size) - if len(seq.logical_blocks) + num_required_blocks > self.block_manager.num_gpu_blocks: - # Reach max gpu cache size. 
- logger.warning(f'session[{seq.session_id}] ' - f'sequence[{seq.seq_id}] ' - 'reach max gpu size.') - self._set_message_status(seq, MessageStatus.ABORTED) - self.block_manager.free(seq) - seq.set_step(0) - continue + assert seq.num_blocks + num_required_blocks <= self.block_manager.num_gpu_blocks, ( + 'Sequence requires more blocks than total gpu blocks.') if not __evict_for_seq(seq, num_required_blocks): - self._set_message_status(seq, MessageStatus.WAITING) + seq.state.evict() continue self.block_manager.allocate(seq, prealloc_size) self.block_trie.allocate(seq) - return self.running, swap_in_map, swap_out_map, copy_map + return self.ready[:self.scheduler_config.max_batches], swap_in_map, swap_out_map, copy_map def schedule(self, is_prefill: bool, prealloc_size: int = 0): """Schedule inputs for next steps.""" @@ -304,37 +254,16 @@ def schedule(self, is_prefill: bool, prealloc_size: int = 0): return SchedulerOutput(running=running, swap_in_map=swap_in_map, swap_out_map=swap_out_map, copy_map=copy_map) - def _set_session_status(self, session_id: int, status: MessageStatus): - """Setup the status of session. + def stop_session(self, session_id: int): + """Stop session. Args: session_id (int): The session id. - status (MessageStatus): New status. """ assert session_id in self.sessions session = self.sessions[session_id] - session.status = status for seq in session.sequences.values(): - seq.status = status - - def stop_session(self, session_id: int): - """Stop session. - - Args: - session_id (int): The session id. - """ - self._set_session_status(session_id, MessageStatus.STOPPED) - - def _remove_sequence(self, seq: SchedulerSequence): - """Remove sequence(unsafe) - - Args: - seq (SchedulerSequence): sequence to remove - """ - self.block_manager.free(seq) - self.state_manager.free(seq) - seq.set_step(0) - seq.session.remove_sequence(seq) + seq.state.stop() def end_session(self, session_id: int): """End session. 
@@ -345,25 +274,20 @@ def end_session(self, session_id: int): session = self.sessions[session_id] seqs = list(session.sequences.values()) for seq in seqs: - self._remove_sequence(seq) + seq.state.stop() + session.remove_sequence(seq) self.sessions.pop(session_id) def has_unfinished(self): """Check if there are any unfinished message.""" - return self.has_running() or self.has_waiting() or self.has_migration_done() + return self.has_ready() or self.has_waiting() or self.has_migration_done() - def has_running(self): - return self.num_running() > 0 + def has_ready(self): + return self.num_ready() > 0 def has_waiting(self): return self.num_waiting() > 0 - def has_to_be_migrated(self): - return self.num_to_be_migrated() > 0 - - def has_migration_running(self): - return self.num_running() > 0 - def has_migration_waiting(self): return self.num_migration_waiting() > 0 @@ -374,71 +298,58 @@ def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def num_running(self): + def num_ready(self): """Num running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING) + return self.seq_manager.num_sequences(MessageStatus.READY) def num_waiting(self): """Num waiting.""" return self.seq_manager.num_sequences(MessageStatus.WAITING) - def num_to_be_migrated(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.TO_BE_MIGRATED) - - def num_migration_locked(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_LOCKED) - - def num_migration_running(self): - """Num migration running.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING_MIGRATION) - def num_migration_done(self): """Num migration done.""" return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) def num_migration_waiting(self): """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING_MIGRATION) + return self.seq_manager.num_sequences(MessageStatus.MIGRATION_WAITING) - def num_locked(self): + def num_running(self): """Num locked.""" - return self.seq_manager.num_sequences(MessageStatus.LOCKED) + return self.seq_manager.num_sequences(MessageStatus.RUNNING) - def lock_running(self, running: SeqList): + def active_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.RUNNING: - self._set_message_status(seq, MessageStatus.LOCKED) + if seq.status == MessageStatus.READY: + seq.state.activate() - def unlock_running(self, locked: SeqList): - for seq in locked: - if seq.status == MessageStatus.LOCKED: - self._set_message_status(seq, MessageStatus.RUNNING) + def deactive_seqs(self, running: SeqList): + for seq in running: + if seq.status == MessageStatus.RUNNING: + seq.state.deactivate() - def lock_running_migration(self, running: SeqList): + def activate_migration_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.RUNNING_MIGRATION: - self._set_message_status(seq, MessageStatus.MIGRATION_LOCKED) + if seq.status == MessageStatus.MIGRATION_READY: + seq.state.activate() - def unlock_running_migration(self, locked: SeqList): + def deactivate_migration_seqs(self, running: SeqList): """Unlock running migration.""" - for seq in locked: - if seq.status == MessageStatus.MIGRATION_LOCKED: - self._set_message_status(seq, MessageStatus.MIGRATION_DONE) + for seq in running: + if seq.status == MessageStatus.MIGRATION_RUNNING: + 
seq.state.deactivate() def collect_migration_done(self): - migration_done = self.migration_done - for seq in migration_done: - self._set_message_status(seq, MessageStatus.RUNNING) + for seq in self.migration_done: + seq.state.activate() @property def schedule_metrics(self): return ScheduleMetrics( - active_seqs=self.num_locked(), - waiting_seqs=self.num_waiting() + self.num_running(), + active_seqs=self.num_running(), + waiting_seqs=self.num_waiting() + self.num_ready(), total_blocks=self.block_manager.num_gpu_blocks, free_blocks=self.block_manager.get_num_free_gpu_blocks(), ) diff --git a/lmdeploy/pytorch/paging/seq_states/__init__.py b/lmdeploy/pytorch/paging/seq_states/__init__.py new file mode 100644 index 0000000000..bba2109f8e --- /dev/null +++ b/lmdeploy/pytorch/paging/seq_states/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .states import StateBase, build_seq_state # noqa: F401 diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py new file mode 100644 index 0000000000..4ab7b1f154 --- /dev/null +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import TYPE_CHECKING + +from lmdeploy.pytorch.messages import MessageStatus, SchedulerSequence + +if TYPE_CHECKING: + from lmdeploy.pytorch.paging import Scheduler + + +def _free_seq(seq: SchedulerSequence, scheduler: 'Scheduler'): + """Free the sequence.""" + if seq.num_blocks > 0: + scheduler.block_manager.free(seq) + if seq.logical_state >= 0: + scheduler.state_manager.free(seq) + seq.set_step(0) + + +class StateBase: + status = None + _registry = dict() + + def __init_subclass__(cls, **kargs) -> None: + super().__init_subclass__(**kargs) + if cls.status: + cls._registry[cls.status] = cls + + @classmethod + def build(cls, scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> 'StateBase': + """Build sequence state.""" + if status not in cls._registry: + raise NotImplementedError(f'Unsupported status {status} for building seq state.') + return cls._registry[status](seq, scheduler) + + def __init__(self, seq: SchedulerSequence, scheduler: 'Scheduler'): + self.seq = seq + self.scheduler = scheduler + + def to_state(self, new_state): + """Transition to a new state.""" + self.scheduler.seq_manager.update_sequence_status(self.seq, new_state.status) + self.seq.set_state(new_state(self.seq, self.scheduler)) + + def evict(self): + """Evict the state.""" + raise NotImplementedError(f'evict not implemented for state {self.status}') + + def activate(self): + """Activate the state.""" + raise NotImplementedError(f'activate not implemented for state {self.status}') + + def deactivate(self): + """Deactivate the state.""" + raise NotImplementedError(f'deactivate not implemented for state {self.status}') + + def finish(self): + """Finish the state.""" + raise NotImplementedError(f'finish not implemented for state {self.status}') + + def stop(self): + """Stop the state.""" + self.to_state(StoppedState) + + def free(self): + """Free the state.""" + _free_seq(self.seq, self.scheduler) + + +class WaitingState(StateBase): + """State for waiting sequences.""" + status = MessageStatus.WAITING + + def activate(self): + """From WAITING to READY.""" + num_req_blocks = self.scheduler.block_manager.num_required_blocks(self.seq) + assert self.seq.num_blocks >= num_req_blocks + if self.scheduler.is_ssm: + assert self.seq.logical_state >= 0 + self.to_state(ReadyState) + + def 
evict(self): + self.to_state(WaitingState) + + +class ReadyState(StateBase): + """State for ready sequences.""" + status = MessageStatus.READY + + def activate(self): + """From READY to RUNNING.""" + self.to_state(RunningState) + + def evict(self): + self.to_state(WaitingState) + + +class StoppedState(StateBase): + """State for stopped sequences.""" + status = MessageStatus.STOPPED + + def activate(self): + """From STOPPED to WAITING.""" + assert self.seq.num_token_ids > 0 + self.to_state(WaitingState) + + +class RunningState(StateBase): + """State for running sequences.""" + status = MessageStatus.RUNNING + + def deactivate(self): + self.to_state(ReadyState) + + def finish(self): + if self.seq.preserve_cache: + self.to_state(ToBeMigratedState) + else: + self.to_state(StoppedState) + + +class ToBeMigratedState(StateBase): + """State for to be migrated sequences.""" + status = MessageStatus.TO_BE_MIGRATED + + +class MigrationWaitingState(StateBase): + """State for migration waiting sequences.""" + status = MessageStatus.MIGRATION_WAITING + + def activate(self): + self.to_state(MigrationReadyState) + + def evict(self): + self.to_state(MigrationWaitingState) + + +class MigrationReadyState(StateBase): + """State for migration ready sequences.""" + status = MessageStatus.MIGRATION_READY + + def activate(self): + self.to_state(MigrationRunningState) + + def evict(self): + self.to_state(MigrationWaitingState) + + +class MigrationDoneState(StateBase): + """State for migration done sequences.""" + status = MessageStatus.MIGRATION_DONE + + def finish(self): + self.to_state(ReadyState) + + +class MigrationRunningState(StateBase): + """State for migration running sequences.""" + status = MessageStatus.MIGRATION_RUNNING + + def finish(self): + self.to_state(MigrationDoneState) + + +def build_seq_state(scheduler: 'Scheduler', seq: 'SchedulerSequence', status: MessageStatus) -> StateBase: + """Build sequence state.""" + return StateBase.build(scheduler, seq, status) diff --git a/lmdeploy/pytorch/strategies/ar/sequence.py b/lmdeploy/pytorch/strategies/ar/sequence.py index 197217c8bb..de7b68e2de 100644 --- a/lmdeploy/pytorch/strategies/ar/sequence.py +++ b/lmdeploy/pytorch/strategies/ar/sequence.py @@ -125,10 +125,10 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d update_mode = UpdateTokenMode.DECODE if is_decoding else UpdateTokenMode.PREFILL for token, msg, stop, model_meta, routed_experts in zip(next_token_ids, running, stopped, model_metas, all_routed_experts): - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue # fill token msg.update_token_ids(token, model_meta=model_meta, mode=update_mode, routed_experts=routed_experts) if stop: - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/lmdeploy/pytorch/strategies/ar_spec/sequence.py b/lmdeploy/pytorch/strategies/ar_spec/sequence.py index ba4236e988..2e272e1473 100644 --- a/lmdeploy/pytorch/strategies/ar_spec/sequence.py +++ b/lmdeploy/pytorch/strategies/ar_spec/sequence.py @@ -179,11 +179,11 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d msg = running[idx] stop = stopped[idx] model_meta = model_metas[idx] - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue cur_draft_tokens = draft_token_ids[idx] # fill token msg.update_token_ids(token, draft_token_ids=cur_draft_tokens, model_meta=model_meta, mode=update_mode) if stop: 
msg.set_stop_pos(stop_pos[idx]) - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/lmdeploy/pytorch/strategies/dllm/sequence.py b/lmdeploy/pytorch/strategies/dllm/sequence.py index ab004a2b63..05c962632c 100644 --- a/lmdeploy/pytorch/strategies/dllm/sequence.py +++ b/lmdeploy/pytorch/strategies/dllm/sequence.py @@ -238,11 +238,11 @@ def update_running(self, running: SeqList, batched_outputs: BatchedOutputs, is_d stop = stopped[idx] model_meta = model_metas[idx] mask = dllm_mask[idx] - if msg.status != MessageStatus.LOCKED: + if msg.status != MessageStatus.RUNNING: continue # fill token msg.update_token_ids(token, dllm_mask=mask, model_meta=model_meta, mode=update_mode) if stop: msg.set_stop_pos(stop_pos[idx]) - msg.status = MessageStatus.TO_BE_MIGRATED if msg.preserve_cache else MessageStatus.STOPPED + msg.state.finish() diff --git a/tests/pytorch/paging/test_block_manager.py b/tests/pytorch/paging/test_block_manager.py index f74b6548cf..b08116d7f4 100644 --- a/tests/pytorch/paging/test_block_manager.py +++ b/tests/pytorch/paging/test_block_manager.py @@ -2,9 +2,10 @@ import pytest import torch -from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta -from lmdeploy.pytorch.paging.block_manager import DefaultBlockManager, WindowBlockManager +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.pytorch.messages import SequenceMeta from lmdeploy.pytorch.paging.block_manager.base_block_manager import LogicalAllocator +from lmdeploy.pytorch.paging.scheduler import Scheduler # yapf: enable @@ -86,18 +87,39 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def block_mgr(self, num_cpu_blocks, num_gpu_blocks): - yield DefaultBlockManager(num_cpu_blocks, num_gpu_blocks) + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks) + + @pytest.fixture + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture - def seq_manager(self, block_size): + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + @pytest.fixture + def block_mgr(self, scheduler): + yield scheduler.block_manager - def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_alloc(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test alloc @@ -121,9 +143,9 @@ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): + def test_num_required_blocks(self, scheduler, block_mgr): from lmdeploy.pytorch.messages import InputEmbeddings - sess = SchedulerSession(0, seq_manager) 
+ sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = torch.tensor([1]) @@ -142,8 +164,8 @@ def test_num_required_blocks(self, block_mgr, seq_manager, num_gpu_blocks): num_required = block_mgr.num_required_blocks(msg) assert num_required == 3 - def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_append_slot(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test append @@ -168,8 +190,8 @@ def test_append_slot(self, block_mgr, seq_manager, num_gpu_blocks): assert len(block_table) == 2 assert block_mgr.get_num_free_gpu_blocks() == num_gpu_blocks - 2 - def test_swap(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_swap(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = torch.tensor([1] * (block_size + 1)) @@ -227,18 +249,40 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def seq_manager(self, block_size): + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size, window_size): + yield CacheConfig(max_batches=max_batch_size, + block_size=block_size, + num_cpu_blocks=num_cpu_blocks, + num_gpu_blocks=num_gpu_blocks, + window_size=window_size) + + @pytest.fixture + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') + + @pytest.fixture + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) @pytest.fixture - def block_mgr(self, num_cpu_blocks, num_gpu_blocks, window_size): - yield WindowBlockManager(num_cpu_blocks, num_gpu_blocks, window_size) + def block_mgr(self, scheduler): + yield scheduler.block_manager - def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): - sess = SchedulerSession(0, seq_manager) + def test_alloc(self, scheduler, block_mgr, num_gpu_blocks): + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # test alloc @@ -262,8 +306,8 @@ def test_alloc(self, block_mgr, seq_manager, num_gpu_blocks): msg = sess.add_sequence(token_ids) assert not block_mgr.can_allocate(msg) - def test_win_alloc(self, block_mgr, seq_manager, num_gpu_blocks, window_size): - sess = SchedulerSession(0, seq_manager) + def test_win_alloc(self, scheduler, block_mgr, num_gpu_blocks, window_size): + sess = scheduler.add_session(0) # 2 win block token_ids = torch.tensor([1] * window_size) diff --git a/tests/pytorch/paging/test_block_trie.py b/tests/pytorch/paging/test_block_trie.py index 7d20c96dab..5736e4d006 100644 --- a/tests/pytorch/paging/test_block_trie.py +++ b/tests/pytorch/paging/test_block_trie.py @@ -1,10 +1,9 @@ import numpy as np import pytest -from lmdeploy.pytorch.config import CacheConfig -from lmdeploy.pytorch.messages import SchedulerSession, SequenceManager, SequenceMeta -from lmdeploy.pytorch.paging.block_manager import build_block_manager -from 
lmdeploy.pytorch.paging.block_trie import BlockTrie +from lmdeploy.pytorch.config import CacheConfig, SchedulerConfig +from lmdeploy.pytorch.messages import SequenceMeta +from lmdeploy.pytorch.paging import Scheduler class TestBlockTire: @@ -22,31 +21,45 @@ def num_gpu_blocks(self): yield 16 @pytest.fixture - def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(max_batches=256, + def max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks, enable_prefix_caching=True) @pytest.fixture - def block_mgr(self, cache_config): - yield build_block_manager(cache_config) + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture - def block_trie(self, cache_config, block_mgr): - yield BlockTrie(cache_config, block_mgr) - - @pytest.fixture - def seq_manager(self, block_size): + def seq_meta(self, block_size): from lmdeploy.pytorch.strategies.ar.sequence import ARSequenceStrategy strategy = ARSequenceStrategy() - seq_meta = SequenceMeta(block_size, strategy=strategy) - yield SequenceManager(seq_meta) + yield SequenceMeta(block_size, strategy=strategy) + + @pytest.fixture + def scheduler(self, cache_config, scheduler_config, seq_meta): + yield Scheduler(scheduler_config=scheduler_config, cache_config=cache_config, seq_meta=seq_meta) + + @pytest.fixture + def block_mgr(self, scheduler): + yield scheduler.block_manager + + @pytest.fixture + def block_trie(self, scheduler): + yield scheduler.block_trie - def test_allocate(self, block_trie, block_mgr, seq_manager): + def test_allocate(self, block_trie, block_mgr, scheduler): allocator = block_trie.allocator - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = ([1] * block_size + [2] * block_size) token_ids += [3] * (block_size // 2) @@ -83,9 +96,9 @@ def test_allocate(self, block_trie, block_mgr, seq_manager): assert node in block_trie.leaves assert len(block_trie.leaves) == 1 - def test_match(self, block_trie, block_mgr, seq_manager): + def test_match(self, block_trie, block_mgr, scheduler): allocator = block_trie.allocator - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size # initialize cache @@ -121,9 +134,9 @@ def test_match(self, block_trie, block_mgr, seq_manager): ref_cnt = allocator.get_ref_count(logical_blocks.get_real_blocks()) assert np.array_equal(ref_cnt, [4, 3]) - def test_evict(self, block_trie, seq_manager, num_gpu_blocks): + def test_evict(self, block_trie, scheduler, num_gpu_blocks): block_mgr = block_trie.block_manager - sess = SchedulerSession(0, seq_manager) + sess = scheduler.add_session(0) block_size = sess.seq_meta.block_size token_ids = ([1] * block_size * (num_gpu_blocks - 1)) token_ids += [2] * (block_size // 2) diff --git a/tests/pytorch/paging/test_scheduler.py b/tests/pytorch/paging/test_scheduler.py index a0acf5f054..3d07c2c019 100644 --- a/tests/pytorch/paging/test_scheduler.py +++ b/tests/pytorch/paging/test_scheduler.py @@ -21,15 +21,22 @@ def num_gpu_blocks(self): yield 4 @pytest.fixture - def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks): - yield CacheConfig(max_batches=256, + def 
max_batch_size(self): + yield 4 + + @pytest.fixture + def cache_config(self, block_size, num_cpu_blocks, num_gpu_blocks, max_batch_size): + yield CacheConfig(max_batches=max_batch_size, block_size=block_size, num_cpu_blocks=num_cpu_blocks, num_gpu_blocks=num_gpu_blocks) @pytest.fixture - def scheduler_config(self): - yield SchedulerConfig(max_batches=4, max_session_len=128, max_request_output_len=64, eviction_type='recompute') + def scheduler_config(self, max_batch_size): + yield SchedulerConfig(max_batches=max_batch_size, + max_session_len=128, + max_request_output_len=64, + eviction_type='recompute') @pytest.fixture def seq_meta(self, block_size): @@ -51,7 +58,6 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): num_blocks = 2 token_ids = torch.tensor([0] * block_size * num_blocks) seq = session.add_sequence(token_ids) - scheduler.add_sequence(seq) assert seq.status == MessageStatus.WAITING assert seq in scheduler.waiting @@ -59,7 +65,7 @@ def test_schedule_base(self, scheduler, block_size, num_gpu_blocks): output = scheduler.schedule(is_prefill=True) block_tables = scheduler.get_block_tables(output.running) - assert seq.status == MessageStatus.RUNNING + assert seq.status == MessageStatus.READY assert seq in output.running assert len(block_tables) == 1 assert len(block_tables[0]) == num_blocks @@ -73,38 +79,34 @@ def test_update(self, scheduler, block_size, num_gpu_blocks): session1 = scheduler.add_session(session_id1) token_ids1 = torch.tensor([0] * block_size * 1) seq1 = session1.add_sequence(token_ids1) - scheduler.add_sequence(seq1) session_id2 = 1 session2 = scheduler.add_session(session_id2) token_ids2 = torch.tensor([0] * block_size * 2) seq2 = session2.add_sequence(token_ids2) - scheduler.add_sequence(seq2) token_ids3 = torch.tensor([0] * block_size * 3) seq3 = session2.add_sequence(token_ids3) - scheduler.add_sequence(seq3) scheduler.schedule(is_prefill=True) - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY + assert seq2.status == MessageStatus.READY assert seq3.status == MessageStatus.WAITING # stop seq - seq1.status = MessageStatus.STOPPED - assert len(scheduler.running) == 1 + seq1.state.stop() + assert len(scheduler.ready) == 1 assert seq1 in scheduler.hanging # end seq - seq1.status = MessageStatus.ENDED - scheduler._remove_sequence(seq1) + seq1.session.remove_sequence(seq1) assert session_id1 in scheduler.sessions - assert seq1 not in scheduler.running + assert seq1 not in scheduler.ready assert seq1 not in scheduler.hanging assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 2 # stop session scheduler.stop_session(session_id2) - assert len(scheduler.running) == 0 + assert len(scheduler.ready) == 0 assert len(scheduler.waiting) == 0 assert len(scheduler.hanging) == 2 @@ -122,25 +124,22 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # test: add 3 seq token_ids1 = torch.tensor([0] * block_size * 1) seq1 = session.add_sequence(token_ids1) - scheduler.add_sequence(seq1) token_ids2 = torch.tensor([0] * block_size * 2) seq2 = session.add_sequence(token_ids2) - scheduler.add_sequence(seq2) token_ids3 = torch.tensor([0] * block_size * 3) seq3 = session.add_sequence(token_ids3) - scheduler.add_sequence(seq3) scheduler.schedule(is_prefill=True) # seq1: 1 running gpu # seq2: 2 running gpu # seq3: 3 waiting empty - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == 
MessageStatus.READY + assert seq2.status == MessageStatus.READY assert seq3.status == MessageStatus.WAITING assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - 3 # test: waiting alloc - seq2.status = MessageStatus.STOPPED - assert len(scheduler.running) == 1 + seq2.state.stop() + assert len(scheduler.ready) == 1 assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 1 @@ -148,17 +147,16 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # seq1: 1 running gpu # seq2: 2 hanging cpu # seq3: 3 running gpu - assert seq1.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY assert seq2.status == MessageStatus.STOPPED - assert seq3.status == MessageStatus.RUNNING + assert seq3.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 # test: waiting append token - seq2.status = MessageStatus.WAITING - seq3.status = MessageStatus.ENDED - scheduler._remove_sequence(seq3) + seq2.state.activate() + seq3.session.remove_sequence(seq3) seq2.update_token_ids(torch.tensor([1] * block_size)) - assert len(scheduler.running) == 1 + assert len(scheduler.ready) == 1 assert len(scheduler.waiting) == 1 assert len(scheduler.hanging) == 0 @@ -166,18 +164,18 @@ def test_evict(self, scheduler, block_size, num_gpu_blocks, num_cpu_blocks): # seq1: 1 running gpu # seq2: 3 running gpu # seq3: 3 nan - assert seq1.status == MessageStatus.RUNNING - assert seq2.status == MessageStatus.RUNNING + assert seq1.status == MessageStatus.READY + assert seq2.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 # test running append seq1.update_token_ids(torch.tensor([1] * block_size)) seq2.update_token_ids(torch.tensor([1] * block_size)) - assert len(scheduler.running) == 2 + assert len(scheduler.ready) == 2 scheduler.schedule(is_prefill=False) # seq1: 1 waiting cpu # seq2: 4 running gpu # seq3: 3 nan assert seq1.status == MessageStatus.WAITING - assert seq2.status == MessageStatus.RUNNING + assert seq2.status == MessageStatus.READY assert block_manager.get_num_free_gpu_blocks() == 0 From d3259fb5f2610e266079b176b9485333f43e5e22 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 21:19:02 +0800 Subject: [PATCH 02/12] fix pd, better property --- lmdeploy/pytorch/paging/scheduler.py | 106 ++++++++----------- lmdeploy/pytorch/paging/seq_states/states.py | 9 ++ 2 files changed, 53 insertions(+), 62 deletions(-) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 1359c88ae5..21917881ab 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -61,41 +61,54 @@ def __init__( seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) self.seq_manager = SequenceManager(seq_meta) - @property - def waiting(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.WAITING) - return list(seq_map.values()) + @staticmethod + def create_status_list_property(status: MessageStatus): + """Create status list property.""" - @property - def ready(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.READY) - return list(seq_map.values()) + def _get_status_list(self): + seq_map = self.seq_manager.get_sequences(status) + return list(seq_map.values()) - @property - def hanging(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.STOPPED) - return list(seq_map.values()) + return property(_get_status_list) - @property 
- def running(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.RUNNING) - return list(seq_map.values()) + @staticmethod + def create_num_status_method(status: MessageStatus): + """Create num status method.""" - @property - def migration_waiting(self): - """Get migration sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_WAITING) - return list(seq_map.values()) + def _num_status(self): + return self.seq_manager.num_sequences(status) - @property - def migration_done(self): - """Get waiting sequence.""" - seq_map = self.seq_manager.get_sequences(MessageStatus.MIGRATION_DONE) - return list(seq_map.values()) + return _num_status + + @staticmethod + def create_has_status_method(status: MessageStatus): + """Create has status method.""" + + def _has_status(self): + return self.seq_manager.num_sequences(status) > 0 + + return _has_status + + # status list properties + waiting = create_status_list_property(MessageStatus.WAITING) + ready = create_status_list_property(MessageStatus.READY) + hanging = create_status_list_property(MessageStatus.STOPPED) + running = create_status_list_property(MessageStatus.RUNNING) + migration_waiting = create_status_list_property(MessageStatus.MIGRATION_WAITING) + migration_done = create_status_list_property(MessageStatus.MIGRATION_DONE) + + # num status methods + num_waiting = create_num_status_method(MessageStatus.WAITING) + num_ready = create_num_status_method(MessageStatus.READY) + num_running = create_num_status_method(MessageStatus.RUNNING) + num_migration_waiting = create_num_status_method(MessageStatus.MIGRATION_WAITING) + num_migration_done = create_num_status_method(MessageStatus.MIGRATION_DONE) + + # has status methods + has_waiting = create_has_status_method(MessageStatus.WAITING) + has_ready = create_has_status_method(MessageStatus.READY) + has_migration_waiting = create_has_status_method(MessageStatus.MIGRATION_WAITING) + has_migration_done = create_has_status_method(MessageStatus.MIGRATION_DONE) def add_session(self, session_id: int): """Add new session. 
@@ -274,6 +287,7 @@ def end_session(self, session_id: int): session = self.sessions[session_id] seqs = list(session.sequences.values()) for seq in seqs: + # stop session so it won't get scheduled again seq.state.stop() session.remove_sequence(seq) self.sessions.pop(session_id) @@ -282,42 +296,10 @@ def has_unfinished(self): """Check if there are any unfinished message.""" return self.has_ready() or self.has_waiting() or self.has_migration_done() - def has_ready(self): - return self.num_ready() > 0 - - def has_waiting(self): - return self.num_waiting() > 0 - - def has_migration_waiting(self): - return self.num_migration_waiting() > 0 - - def has_migration_done(self): - return self.num_migration_done() > 0 - def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def num_ready(self): - """Num running.""" - return self.seq_manager.num_sequences(MessageStatus.READY) - - def num_waiting(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.WAITING) - - def num_migration_done(self): - """Num migration done.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_DONE) - - def num_migration_waiting(self): - """Num waiting.""" - return self.seq_manager.num_sequences(MessageStatus.MIGRATION_WAITING) - - def num_running(self): - """Num locked.""" - return self.seq_manager.num_sequences(MessageStatus.RUNNING) - def active_seqs(self, running: SeqList): """Lock running sequence.""" for seq in running: diff --git a/lmdeploy/pytorch/paging/seq_states/states.py b/lmdeploy/pytorch/paging/seq_states/states.py index 4ab7b1f154..1f44f02111 100644 --- a/lmdeploy/pytorch/paging/seq_states/states.py +++ b/lmdeploy/pytorch/paging/seq_states/states.py @@ -122,6 +122,9 @@ class ToBeMigratedState(StateBase): """State for to be migrated sequences.""" status = MessageStatus.TO_BE_MIGRATED + def finish(self): + self.to_state(StoppedState) + class MigrationWaitingState(StateBase): """State for migration waiting sequences.""" @@ -149,6 +152,9 @@ class MigrationDoneState(StateBase): """State for migration done sequences.""" status = MessageStatus.MIGRATION_DONE + def activate(self): + self.to_state(ReadyState) + def finish(self): self.to_state(ReadyState) @@ -157,6 +163,9 @@ class MigrationRunningState(StateBase): """State for migration running sequences.""" status = MessageStatus.MIGRATION_RUNNING + def deactivate(self): + self.to_state(MigrationDoneState) + def finish(self): self.to_state(MigrationDoneState) From 59138541055f00844acdfadc90ca6a9e7b2ad5cb Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 27 Nov 2025 21:21:59 +0800 Subject: [PATCH 03/12] skip decoding warmup --- lmdeploy/pytorch/engine/model_agent.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index c6c9345539..6be141fd7f 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -440,6 +440,9 @@ def warmup(self): # warmup decoding(with cuda graph) capture_batch_sizes = self.patched_model.get_capture_batch_sizes() capture_batch_sizes = sorted(capture_batch_sizes, reverse=True) + if self.cache_config.role == EngineRole.Prefill: + # do not warmup decoding for prefill engine + capture_batch_sizes = [] for num_tokens in capture_batch_sizes: inputs = self.inputs_strategy.make_dummy(num_tokens, is_decoding=True, From b7b9f4b56fe84192cb5a3e65f9375e7944b9ad9b Mon Sep 17 00:00:00 2001 From: "q.yao" Date: 
Mon, 1 Dec 2025 19:58:48 +0800 Subject: [PATCH 04/12] rename --- lmdeploy/pytorch/engine/engine.py | 16 ++++--------- lmdeploy/pytorch/engine/request.py | 2 +- lmdeploy/pytorch/messages.py | 10 ++++---- lmdeploy/pytorch/paging/scheduler.py | 35 +++++++++++++--------------- 4 files changed, 27 insertions(+), 36 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 609f8fa274..2fe9017d88 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -656,11 +656,10 @@ def __update_max_new_tokens(msg): scheduler = self.scheduler for req in reqs: session_id = req.data['session_id'] - if scheduler is None: + sess = scheduler.sessions.get(session_id, None) + if sess is None: self._response(req.resp, ResponseType.SESSION_NOT_EXIST) continue - session_id = req.data['session_id'] - sess = scheduler.sessions[session_id] # TODO: support 1 session n sequence sampling_param = req.data['sampling_param'] if len(sess.sequences) == 0: @@ -675,7 +674,6 @@ def __update_max_new_tokens(msg): resp_cache=req.data.get('with_cache'), preserve_cache=req.data.get('preserve_cache')) msg = next(iter(sess.sequences.values())) - __update_max_new_tokens(msg) if migration_request: self.migration_event.set() else: @@ -688,8 +686,8 @@ def __update_max_new_tokens(msg): ) msg.sampling_param = sampling_param msg.state.activate() - __update_max_new_tokens(msg) + __update_max_new_tokens(msg) msg.resp = req.resp @property @@ -697,10 +695,6 @@ def model_config(self) -> ModelConfig: """Model config.""" return self.executor.model_config - @property - def gpu_count(self): - return self.dist_config.world_size - @property def torch_int_dtype(self): """Return int32 for cuda, int64 for others.""" @@ -1309,8 +1303,8 @@ async def async_loop(self): forward_event=forward_event, has_runable_event=has_runable_event, inputs_maker=inputs_maker) - except Exception as e: - logger.exception(f'exception happened: {type(e)} {e}') + except Exception: + logger.exception('Engine main loop failed.') finally: self._loop_finally() diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 466e102e22..268d9556dd 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -108,7 +108,7 @@ def _gather_request(self, req_types: List[RequestType], data: List[Any]): resps = [] for rtype, rdata in zip(req_types, data): event = asyncio.Event() - resp = Response(type=ResponseType.HANDLER_NOT_EXIST, + resp = Response(type=ResponseType.INTERNAL_ENGINE_ERROR, sender_id=self.sender_id, event=event, data=None, diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 1fd26ea2cd..699988f9b6 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
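
In the messages.py hunk that continues just below, SequenceManager switches its per-status index to a defaultdict and add_sequence now reads the status from the sequence itself instead of taking it as an argument. A minimal sketch of that pattern, with `TinySeq` as an illustrative stand-in for SchedulerSequence:

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class TinySeq:
    """Illustrative stand-in for SchedulerSequence."""
    seq_id: int
    status: str


class TinySequenceManager:

    def __init__(self):
        self._seq_map = {}
        # defaultdict removes the need to pre-create one dict per status
        self._status_seq_map = defaultdict(dict)

    def add_sequence(self, seq: TinySeq):
        # the status is read from the sequence, not passed by the caller
        self._seq_map[seq.seq_id] = seq
        self._status_seq_map[seq.status][seq.seq_id] = seq

    def num_sequences(self, status: str):
        return len(self._status_seq_map[status])


manager = TinySequenceManager()
manager.add_sequence(TinySeq(seq_id=0, status='WAITING'))
assert manager.num_sequences('WAITING') == 1
assert manager.num_sequences('READY') == 0  # defaultdict creates entries lazily
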
import enum +from collections import defaultdict from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -179,9 +180,7 @@ class SequenceManager: def __init__(self, seq_meta: SequenceMeta) -> None: self._seq_map: SeqMap = dict() - self._status_seq_map: Dict[MessageStatus, SeqMap] = dict() - for status in MessageStatus: - self._status_seq_map[status] = dict() + self._status_seq_map: Dict[MessageStatus, SeqMap] = defaultdict(dict) self.seq_meta = seq_meta self._seq_count = 0 @@ -203,9 +202,10 @@ def num_sequences(self, status: MessageStatus): """Num sequences.""" return len(self.get_sequences(status)) - def add_sequence(self, seq: 'SchedulerSequence', status: MessageStatus): + def add_sequence(self, seq: 'SchedulerSequence'): """Add sequence.""" seq_id = seq.seq_id + status = seq.status status_map = self._status_seq_map[status] self._seq_map[seq_id] = seq status_map[seq_id] = seq @@ -288,7 +288,7 @@ def add_sequence(self, # update seq manager status = MessageStatus.WAITING if migration_request is None else MessageStatus.MIGRATION_WAITING seq.set_state(build_seq_state(self.scheduler, seq, status)) - self.seq_manager.add_sequence(seq, status) + self.seq_manager.add_sequence(seq) # metrics seq.record_event(EventType.QUEUED) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 637dfa5ab6..e60fdc358f 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -17,6 +17,7 @@ logger = get_logger('lmdeploy') +MapType = Dict[int, int] SeqList = List[SchedulerSequence] @@ -25,9 +26,9 @@ class SchedulerOutput: """Output of schedule.""" running: SeqList - swap_in_map: Dict[int, int] - swap_out_map: Dict[int, int] - copy_map: Dict[int, int] + swap_in_map: MapType + swap_out_map: MapType + copy_map: MapType class Scheduler: @@ -165,9 +166,9 @@ def _schedule_prefill(self, prealloc_size: int = 0): max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running() eviction_helper = self.eviction_helper - swap_out_map: Dict[int, int] = dict() - swap_in_map: Dict[int, int] = dict() - copy_map: Dict[int, int] = dict() + swap_out_map: MapType = dict() + swap_in_map: MapType = dict() + copy_map: MapType = dict() running: SeqList = [] token_count = 0 @@ -227,9 +228,9 @@ def _reorder_running(): assert len(running) != 0 eviction_helper = self.eviction_helper - swap_out_map: Dict[int, int] = dict() - swap_in_map: Dict[int, int] = dict() - copy_map: Dict[int, int] = dict() + swap_out_map: MapType = dict() + swap_in_map: MapType = dict() + copy_map: MapType = dict() def __evict_for_seq(seq: SchedulerSequence, num_required_blocks: int): """Evict until can append.""" @@ -312,28 +313,24 @@ def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def active_seqs(self, running: SeqList): + def active_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY): """Lock running sequence.""" for seq in running: - if seq.status == MessageStatus.READY: + if seq.status == filter_status: seq.state.activate() - def deactive_seqs(self, running: SeqList): + def deactive_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING): for seq in running: - if seq.status == MessageStatus.RUNNING: + if seq.status == filter_status: seq.state.deactivate() def activate_migration_seqs(self, running: SeqList): """Lock running sequence.""" - for seq in running: - if 
seq.status == MessageStatus.MIGRATION_READY: - seq.state.activate() + return self.active_seqs(running, filter_status=MessageStatus.MIGRATION_READY) def deactivate_migration_seqs(self, running: SeqList): """Unlock running migration.""" - for seq in running: - if seq.status == MessageStatus.MIGRATION_RUNNING: - seq.state.deactivate() + return self.deactive_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING) def collect_migration_done(self): for seq in self.migration_done: From 686452aaee8a98903511483fa1f078fbe6ff1578 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Mon, 1 Dec 2025 21:23:41 +0800 Subject: [PATCH 05/12] add more profile logs --- lmdeploy/pytorch/engine/engine.py | 3 ++- lmdeploy/pytorch/paging/scheduler.py | 4 ++++ lmdeploy/pytorch/strategies/ar/sampling.py | 2 ++ lmdeploy/pytorch/strategies/dllm/sampling.py | 3 +++ 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 2fe9017d88..9de78855ca 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -767,7 +767,7 @@ def __has_values(input_multimodals): return vision_embedding_inputs @torch.inference_mode() - @record_function('CreateModelInputs') + @record_function('create_model_inputs') def create_model_inputs(self, messages: SeqList, is_prefill: bool): """Create model inputs from messages. @@ -932,6 +932,7 @@ def _make_infer_outputs( outputs[session_id].logits = logits.split(seq_length)[idx] return outputs + @record_function('make_forward_inputs') def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): """Make forward inputs.""" diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index e60fdc358f..bbf7ff903a 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -5,6 +5,8 @@ from dataclasses import dataclass from typing import Dict, List +from torch.profiler import record_function + from lmdeploy.messages import EventType, ScheduleMetrics from lmdeploy.utils import get_logger @@ -161,6 +163,7 @@ def _reorder_migrating(): return migration_ready + @record_function('schedule_prefill') def _schedule_prefill(self, prealloc_size: int = 0): """Schedule for prefilling.""" @@ -217,6 +220,7 @@ def _reorder_waiting(): return running, swap_in_map, swap_out_map, copy_map + @record_function('schedule_decoding') def _schedule_decoding(self, prealloc_size: int = 0): """Schedule decoding.""" diff --git a/lmdeploy/pytorch/strategies/ar/sampling.py b/lmdeploy/pytorch/strategies/ar/sampling.py index 3b97940e9b..b37ce85e8f 100644 --- a/lmdeploy/pytorch/strategies/ar/sampling.py +++ b/lmdeploy/pytorch/strategies/ar/sampling.py @@ -2,6 +2,7 @@ from typing import List import torch +from torch.profiler import record_function from lmdeploy.pytorch.engine.logits_process import SamplingInputs from lmdeploy.pytorch.messages import SchedulerSequence @@ -41,6 +42,7 @@ def __init__(self, pad_token_id: int) -> None: self.pad_token_id = pad_token_id self.session_to_cleanup = [] + @record_function('make_sampling_inputs') def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: """Create sampling inputs from the sequences.""" batch_size = len(seqs) diff --git a/lmdeploy/pytorch/strategies/dllm/sampling.py b/lmdeploy/pytorch/strategies/dllm/sampling.py index 45048e25a5..c181704b0a 100644 --- a/lmdeploy/pytorch/strategies/dllm/sampling.py +++ b/lmdeploy/pytorch/strategies/dllm/sampling.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. 
All rights reserved. from typing import List +from torch.profiler import record_function + from lmdeploy.pytorch.engine.logits_process import SamplingInputs from lmdeploy.pytorch.messages import SchedulerSequence @@ -16,6 +18,7 @@ def __init__(self, pad_token_id: int, dllm_block_length: int) -> None: super().__init__(pad_token_id) self.dllm_block_length = dllm_block_length + @record_function('make_sampling_inputs') def make_sampling_inputs(self, seqs: SeqList) -> SamplingInputs: """Create sampling inputs from the sequences.""" out = super().make_sampling_inputs(seqs) From 3385d04a39e193bcda4c213d77eafc387e223fe4 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 2 Dec 2025 14:09:34 +0800 Subject: [PATCH 06/12] add config builder --- lmdeploy/pytorch/engine/config_builder.py | 106 +++++++++++++++++++ lmdeploy/pytorch/engine/engine.py | 119 +++------------------- 2 files changed, 119 insertions(+), 106 deletions(-) create mode 100644 lmdeploy/pytorch/engine/config_builder.py diff --git a/lmdeploy/pytorch/engine/config_builder.py b/lmdeploy/pytorch/engine/config_builder.py new file mode 100644 index 0000000000..d5c0fd7241 --- /dev/null +++ b/lmdeploy/pytorch/engine/config_builder.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os + +from lmdeploy.messages import PytorchEngineConfig, SpeculativeConfig +from lmdeploy.pytorch.config import (BackendConfig, CacheConfig, DistConfig, MiscConfig, SchedulerConfig, + SpecDecodeConfig) +from lmdeploy.utils import get_logger, get_max_batch_size, get_model + + +class ConfigBuilder: + + @staticmethod + def update_engine_config(engine_config: PytorchEngineConfig): + """Update pytorch engine config.""" + logger = get_logger('lmdeploy') + + # make sure engine exits + if engine_config is None: + engine_config = PytorchEngineConfig() + else: + engine_config = copy.deepcopy(engine_config) + + if engine_config.max_batch_size is None: + engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) + + if engine_config.dllm_block_length is not None: + max_prefill_token_num = engine_config.max_prefill_token_num + max_batch_size = engine_config.max_batch_size + if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: + engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length + logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' + f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' + f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') + + if engine_config.dp != 1: + if engine_config.tp == 1 and engine_config.ep == 1: + engine_config.dp = 1 + engine_config.dp_rank = 0 + + return engine_config + + @staticmethod + def build_scheduler_config(engine_config: PytorchEngineConfig): + """Build scheduler config.""" + scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size, + max_session_len=engine_config.session_len, + prefill_interval=engine_config.prefill_interval) + return scheduler_config + + @staticmethod + def build_cache_config(engine_config: PytorchEngineConfig): + """Build cache config.""" + cache_config = CacheConfig(max_batches=engine_config.max_batch_size, + block_size=engine_config.block_size, + num_cpu_blocks=engine_config.num_cpu_blocks, + num_gpu_blocks=engine_config.num_gpu_blocks, + cache_max_entry_count=engine_config.cache_max_entry_count, + max_prefill_token_num=engine_config.max_prefill_token_num, + enable_prefix_caching=engine_config.enable_prefix_caching, + 
quant_policy=engine_config.quant_policy, + device_type=engine_config.device_type, + migration_backend=engine_config.migration_backend, + role=engine_config.role) + return cache_config + + @staticmethod + def build_backend_config(engine_config: PytorchEngineConfig): + """Build backend config.""" + backend_config = BackendConfig( + eager_mode=engine_config.eager_mode, + device_type=engine_config.device_type, + ) + return backend_config + + @staticmethod + def build_dist_config(engine_config: PytorchEngineConfig): + """Build dist config.""" + dist_config = DistConfig.from_engine_config(engine_config=engine_config) + return dist_config + + @staticmethod + def build_misc_config(engine_config: PytorchEngineConfig): + """Build misc config.""" + misc_config = MiscConfig.from_engine_config(engine_config) + return misc_config + + @staticmethod + def build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig, + cache_config: CacheConfig): + """Build spec decode config.""" + specdecode_config = None + if speculative_config is not None: + draft_model = speculative_config.model + if draft_model and not os.path.exists(speculative_config.model): + draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision) + + specdecode_config = SpecDecodeConfig.from_config( + method=speculative_config.method, + num_speculative_tokens=speculative_config.num_speculative_tokens, + model=draft_model, + target_model=target_model, + target_cache_cfg=cache_config, + dtype=engine_config.dtype, + ) + return specdecode_config diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 9de78855ca..a28c72bf5d 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
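
Taken together, the ConfigBuilder above centralizes the config derivation that Engine.__init__ previously did inline; the engine-side call sites appear later in this patch. A hedged usage sketch, assuming an lmdeploy installation that already contains this new module; the model path is a placeholder:

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.pytorch.engine.config_builder import ConfigBuilder

# normalize user-provided options first (fills max_batch_size, dp fix-ups, ...)
engine_config = ConfigBuilder.update_engine_config(PytorchEngineConfig())

scheduler_config = ConfigBuilder.build_scheduler_config(engine_config)
cache_config = ConfigBuilder.build_cache_config(engine_config)
backend_config = ConfigBuilder.build_backend_config(engine_config)
dist_config = ConfigBuilder.build_dist_config(engine_config)
misc_config = ConfigBuilder.build_misc_config(engine_config)

# speculative decoding is optional; passing None keeps it disabled
specdecode_config = ConfigBuilder.build_specdecode_config(
    target_model='<model-path>',  # placeholder
    speculative_config=None,
    engine_config=engine_config,
    cache_config=cache_config,
)
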
import asyncio -import copy import gc import logging import os @@ -18,15 +17,16 @@ from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch -from lmdeploy.utils import get_logger, get_max_batch_size, get_model +from lmdeploy.utils import get_logger, get_model from ..adapter.adapter import AdapterManager -from ..config import BackendConfig, CacheConfig, DistConfig, MiscConfig, ModelConfig, SchedulerConfig, SpecDecodeConfig +from ..config import CacheConfig, ModelConfig from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode from ..model_inputs import ModelInputs, VisionModelInputs from ..paging import Scheduler from ..strategies import build_strategy_factory from .base import EngineBase +from .config_builder import ConfigBuilder from .engine_checker import EngineChecker from .executor import build_executor from .model_agent import BatchedOutputs @@ -75,99 +75,6 @@ def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): return torch.as_tensor(out, dtype=dtype) -def _update_engine_config(engine_config: PytorchEngineConfig): - """Update pytorch engine config.""" - # make sure engine exits - if engine_config is None: - engine_config = PytorchEngineConfig() - else: - engine_config = copy.deepcopy(engine_config) - - if engine_config.max_batch_size is None: - engine_config.max_batch_size = get_max_batch_size(engine_config.device_type) - - if engine_config.dllm_block_length is not None: - max_prefill_token_num = engine_config.max_prefill_token_num - max_batch_size = engine_config.max_batch_size - if max_batch_size * engine_config.dllm_block_length > max_prefill_token_num: - engine_config.max_batch_size = max_prefill_token_num // engine_config.dllm_block_length - logger.warning(f'Update max_batch_size to {engine_config.max_batch_size} ' - f'since dllm_block_length({engine_config.dllm_block_length}) * max_batch_size ' - f'({max_batch_size}) > max_prefill_token_num ({max_prefill_token_num}).') - - if engine_config.dp != 1: - if engine_config.tp == 1 and engine_config.ep == 1: - engine_config.dp = 1 - engine_config.dp_rank = 0 - - return engine_config - - -def _build_scheduler_config(engine_config: PytorchEngineConfig): - """Build scheduler config.""" - scheduler_config = SchedulerConfig(max_batches=engine_config.max_batch_size, - max_session_len=engine_config.session_len, - prefill_interval=engine_config.prefill_interval) - return scheduler_config - - -def _build_cache_config(engine_config: PytorchEngineConfig): - """Build cache config.""" - cache_config = CacheConfig(max_batches=engine_config.max_batch_size, - block_size=engine_config.block_size, - num_cpu_blocks=engine_config.num_cpu_blocks, - num_gpu_blocks=engine_config.num_gpu_blocks, - cache_max_entry_count=engine_config.cache_max_entry_count, - max_prefill_token_num=engine_config.max_prefill_token_num, - enable_prefix_caching=engine_config.enable_prefix_caching, - quant_policy=engine_config.quant_policy, - device_type=engine_config.device_type, - migration_backend=engine_config.migration_backend, - role=engine_config.role) - return cache_config - - -def _build_backend_config(engine_config: PytorchEngineConfig): - """Build backend config.""" - backend_config = BackendConfig( - eager_mode=engine_config.eager_mode, - device_type=engine_config.device_type, - ) - return backend_config - - -def _build_dist_config(engine_config: PytorchEngineConfig): - """Build dist config.""" - 
dist_config = DistConfig.from_engine_config(engine_config=engine_config) - return dist_config - - -def _build_misc_config(engine_config: PytorchEngineConfig): - """Build misc config.""" - misc_config = MiscConfig.from_engine_config(engine_config) - return misc_config - - -def _build_specdecode_config(target_model, speculative_config: SpeculativeConfig, engine_config: PytorchEngineConfig, - cache_config: CacheConfig): - """Build spec decode config.""" - specdecode_config = None - if speculative_config is not None: - draft_model = speculative_config.model - if draft_model and not os.path.exists(speculative_config.model): - draft_model = get_model(draft_model, engine_config.download_dir, engine_config.revision) - - specdecode_config = SpecDecodeConfig.from_config( - method=speculative_config.method, - num_speculative_tokens=speculative_config.num_speculative_tokens, - model=draft_model, - target_model=target_model, - target_cache_cfg=cache_config, - dtype=engine_config.dtype, - ) - return specdecode_config - - def _build_seq_meta(cache_config: CacheConfig, strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta @@ -202,11 +109,11 @@ def clear(self): class RunableEventBase: """Runable event base.""" - async def wait(self, idx: int): + async def wait(self): """Wait event.""" raise NotImplementedError('Not implemented.') - def set(self, idx: int = None): + def set(self): """Set event.""" raise NotImplementedError('Not implemented.') @@ -365,7 +272,7 @@ def __init__( speculative_config: SpeculativeConfig = None, ) -> None: # make sure engine config exist - engine_config = _update_engine_config(engine_config) + engine_config = ConfigBuilder.update_engine_config(engine_config) # frequently gc would cause latency spike # default threshold (700, 10, 10) @@ -393,14 +300,14 @@ def __init__( checker.handle() # build configs - scheduler_config = _build_scheduler_config(engine_config) - cache_config = _build_cache_config(engine_config) - backend_config = _build_backend_config(engine_config) - dist_config = _build_dist_config(engine_config) - misc_config = _build_misc_config(engine_config) - + scheduler_config = ConfigBuilder.build_scheduler_config(engine_config) + cache_config = ConfigBuilder.build_cache_config(engine_config) + backend_config = ConfigBuilder.build_backend_config(engine_config) + dist_config = ConfigBuilder.build_dist_config(engine_config) + misc_config = ConfigBuilder.build_misc_config(engine_config) # spec decode - self.specdecode_config = _build_specdecode_config(model_path, speculative_config, engine_config, cache_config) + self.specdecode_config = ConfigBuilder.build_specdecode_config(model_path, speculative_config, engine_config, + cache_config) # build model agent self.executor = build_executor( From c30b4218a220da3f57cbe4f2f53da0a1fe600fe6 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 2 Dec 2025 21:51:00 +0800 Subject: [PATCH 07/12] add engine_loop and input_maker --- lmdeploy/pytorch/engine/engine.py | 808 +------------------------ lmdeploy/pytorch/engine/engine_loop.py | 514 ++++++++++++++++ lmdeploy/pytorch/engine/input_maker.py | 415 +++++++++++++ lmdeploy/pytorch/engine/request.py | 2 +- lmdeploy/pytorch/paging/scheduler.py | 29 +- 5 files changed, 983 insertions(+), 785 deletions(-) create mode 100644 lmdeploy/pytorch/engine/engine_loop.py create mode 100644 lmdeploy/pytorch/engine/input_maker.py diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index a28c72bf5d..7c2c30d473 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ 
b/lmdeploy/pytorch/engine/engine.py @@ -1,43 +1,35 @@ # Copyright (c) OpenMMLab. All rights reserved. import asyncio import gc -import logging import os -import time from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Union import numpy as np import torch -from torch.profiler import record_function from lmdeploy.messages import PytorchEngineConfig, RequestMetrics, ResponseType, SpeculativeConfig from lmdeploy.pytorch.disagg.config import EngineRole from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection from lmdeploy.pytorch.disagg.conn.protocol import (DistServeConnectionRequest, DistServeDropConnectionRequest, DistServeInitRequest) -from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch from lmdeploy.utils import get_logger, get_model from ..adapter.adapter import AdapterManager from ..config import CacheConfig, ModelConfig -from ..messages import MessageStatus, SchedulerSequence, UpdateTokenMode -from ..model_inputs import ModelInputs, VisionModelInputs +from ..messages import SchedulerSequence, UpdateTokenMode from ..paging import Scheduler from ..strategies import build_strategy_factory from .base import EngineBase from .config_builder import ConfigBuilder from .engine_checker import EngineChecker from .executor import build_executor -from .model_agent import BatchedOutputs from .request import Request, RequestManager, RequestType, Response logger = get_logger('lmdeploy') SeqList = List[SchedulerSequence] -_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) - @dataclass class InferOutput: @@ -45,7 +37,7 @@ class InferOutput: session_id: int resp: Response - token_ids: List[int] + token_ids: Union[np.ndarray, List[int]] meta: Any = None finish: bool = False logits: torch.Tensor = None @@ -62,19 +54,6 @@ class InferOutput: routed_experts: torch.Tensor = None -def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): - """Tensorlize block_offsets.""" - # copy on numpy is faster than torch.nn.utils.rnn.pad_sequence - batch_size = len(block_offsets) - max_len = max([len(off) for off in block_offsets]) - out = np.zeros((batch_size, max_len), dtype=block_offsets[0].dtype) - - for idx, off in enumerate(block_offsets): - off_len = len(off) - out[idx, :off_len] = off - return torch.as_tensor(out, dtype=dtype) - - def _build_seq_meta(cache_config: CacheConfig, strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta @@ -82,177 +61,18 @@ def _build_seq_meta(cache_config: CacheConfig, strategy: Any): return seq_meta -class CounterEvent: - - def __init__(self): - self._counter = 0 - self._event = asyncio.Event() - - async def wait(self): - await self._event.wait() - - def is_set(self): - return self._event.is_set() - - def set(self): - if self._counter > 0: - self._counter -= 1 - if self._counter == 0: - self._event.set() - - def clear(self): - if self._counter == 0 and self._event.is_set(): - self._event.clear() - self._counter += 1 - - -class RunableEventBase: - """Runable event base.""" - - async def wait(self): - """Wait event.""" - raise NotImplementedError('Not implemented.') - - def set(self): - """Set event.""" - raise NotImplementedError('Not implemented.') - - -class RunableEventAsnyc(RunableEventBase): - """Awaitable async runable event.""" - - def __init__(self, scheduler: Scheduler): - self.scheduler = scheduler - self.event = asyncio.Event() - - async def wait(self): - """Wait event.""" - await self.event.wait() - - def set(self): - """Set event.""" - if 
self.scheduler.has_unfinished(): - self.event.set() - else: - self.event.clear() - - -def build_runable_event(scheduler: Scheduler): - """Build runable event.""" - return RunableEventAsnyc(scheduler) - - -class InputsMakerBase: - - def __init__(self, engine: 'Engine'): - self.engine = engine - self.scheduler_config = engine.scheduler_config - self.executor = engine.executor - - def _make_forward_inputs(self, *args, **kwargs): - """Make forward inputs.""" - return self.engine._make_forward_inputs(*args, **kwargs) - - async def send_next_inputs(self): - """Send next input.""" - raise NotImplementedError('Not implemented.') - - async def prefetch_next_inputs(self): - """prefetch.""" - raise NotImplementedError('Not implemented.') - - -class InputsMakerAsync(InputsMakerBase): - - def __init__(self, engine: 'Engine'): - super().__init__(engine) - self.scheduler = self.engine.scheduler - self.forward_inputs = None - self.spec_decoding = engine.specdecode_config is not None - - self.dp = self.engine.dist_config.dp - self.role = self.engine.cache_config.role - - self.next_is_prefill = True - if self.dp == 1: - self.do_prefill = self.do_prefill_default - else: - self.do_prefill = self.do_prefill_dp - - def do_prefill_dp(self): - if self.role == EngineRole.Prefill: - return True - - scheduler = self.scheduler - - if self.next_is_prefill: - ret = scheduler.has_waiting() - else: - ret = not scheduler.has_ready() - return ret - - def do_prefill_default(self): - # decoding if no waiting - scheduler = self.scheduler - if not scheduler.has_waiting(): - return False - num_ready = scheduler.num_ready() - num_waiting = scheduler.num_waiting() - max_batches = self.scheduler_config.max_batches - # prefill if too much waiting - permitted_waiting = 4 if (self.engine.engine_config.role != EngineRole.Prefill) else 1 - if num_waiting >= permitted_waiting: - return True - # prefill if no enough running - if num_ready < max_batches * 0.5: - return True - # decoding - return False - - async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool = False): - forward_inputs = self._make_forward_inputs(prefill, enable_empty) - if forward_inputs is None: - return None, None - next_running = forward_inputs.pop('running') - inputs = forward_inputs['inputs'] - logger.debug(f'Sending forward inputs: {inputs.log_info()}') - if logger.level <= logging.DEBUG: - session_ids = [seq.session_id for seq in next_running] - logger.debug(f'Forward session_ids: {session_ids}') - self.next_is_prefill = inputs.is_decoding - await self.executor.forward_async(forward_inputs) - self.forward_inputs = forward_inputs - return forward_inputs, next_running - - async def send_next_inputs(self): - prefill = self.do_prefill() - return await self._send_next_inputs_impl(prefill) - - async def prefetch_next_inputs(self): - enable = False - scheduler = self.scheduler - prefill = self.do_prefill() - if prefill: - enable = True - else: - num_ready = scheduler.num_ready() - is_decoding = self.forward_inputs['inputs'].is_decoding - running_threshold = (self.scheduler_config.max_batches // 4) if is_decoding or self.spec_decoding else 0 - - if num_ready > running_threshold: - enable = True - - if enable: - # send next forward - logger.debug('Prefetching next forward inputs.') - return await self._send_next_inputs_impl(prefill, True) - else: - return None, None - - -def build_inputs_maker(engine: 'Engine'): - """Build inputs makers.""" - return InputsMakerAsync(engine) +def response_reqs(req_manager: RequestManager, + resp: Response, + 
resp_type: ResponseType, + data: Any = None, + err_msg: str = ''): + """response.""" + if resp.type == ResponseType.FINISH: + return + resp.type = resp_type + resp.data = data + resp.err_msg = err_msg + req_manager.response(resp) class Engine(EngineBase): @@ -354,7 +174,7 @@ def __init__( self.req_manager = self._bind_request_manager() # create main thread - self._start_loop() + self.req_manager.set_main_loop(self.async_loop) self._loop_main = None # for PD Disaggregation @@ -433,18 +253,9 @@ def _bind_request_manager(self): req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message) return req_manager - def _start_loop(self): - """Start loop.""" - return self.req_manager.start_loop(self.async_loop) - def _response(self, resp: Response, resp_type: ResponseType, data: Any = None, err_msg: str = ''): """response.""" - if resp.type == ResponseType.FINISH: - return - resp.type = resp_type - resp.data = data - resp.err_msg = err_msg - self.req_manager.response(resp) + return response_reqs(self.req_manager, resp, resp_type, data, err_msg) def _get_max_session_len(self): """Get max session len.""" @@ -602,397 +413,6 @@ def model_config(self) -> ModelConfig: """Model config.""" return self.executor.model_config - @property - def torch_int_dtype(self): - """Return int32 for cuda, int64 for others.""" - if self.executor.device_type == 'cuda': - return torch.int32 - return torch.int64 - - def _create_vision_model_inputs(self, messages: SeqList, model_inputs: ModelInputs): - """Create vision model inputs.""" - batch_size = len(messages) - - def __get_vlm_embeddings(): - """Get vlm input embeddings and indexings.""" - max_q_seq_length = model_inputs.seq_length.max().item() - input_embeddings = [[ - emb.embeddings if isinstance(emb.embeddings, torch.Tensor) else torch.as_tensor(emb.embeddings) - for emb in msg.input_embeddings - ] for msg in messages] - input_embedding_ranges = [ - torch.tensor([[emb.start, emb.end] for emb in msg.input_embeddings]) for msg in messages - ] - input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool) - for msg_id, msg in enumerate(messages): - num_history_ids = msg.num_history_ids - for emb in msg.input_embeddings: - # make slice index relative to embeddings - emb_start = emb.start - num_history_ids - emb_end = emb.end - num_history_ids - input_embedding_indexing[msg_id][emb_start:emb_end] = True - return (input_embeddings, input_embedding_indexing, input_embedding_ranges) - - def __has_values(input_multimodals): - for input_mm in input_multimodals: - for val in input_mm.values(): - if len(val) > 0: - return True - return False - - has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) - if has_embedding: - has_embedding = any([len(msg.input_embeddings) > 0 for msg in messages]) - - has_multimodal = any([not msg.history_multimodals.empty() for msg in messages]) - input_multimodals = None - if has_multimodal: - input_multimodals = [msg.get_input_multimodals() for msg in messages] - has_multimodal = __has_values(input_multimodals) - if not has_multimodal: - # no multimodal inputs - input_multimodals = None - - if not has_embedding and not has_multimodal: - # no vision inputs - return None - - if has_embedding: - # for inputs with embeddings - (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings() - else: - input_embeddings = None - input_embedding_indexing = None - input_embedding_ranges = None - - history_lengths = model_inputs.history_lengths - vision_embedding_inputs = 
VisionModelInputs(history_lengths=history_lengths, - input_embeddings=input_embeddings, - input_embedding_indexing=input_embedding_indexing, - input_embedding_ranges=input_embedding_ranges, - input_multimodals=input_multimodals) - return vision_embedding_inputs - - @torch.inference_mode() - @record_function('create_model_inputs') - def create_model_inputs(self, messages: SeqList, is_prefill: bool): - """Create model inputs from messages. - - Args: - messages (SeqList): The input messages. - """ - batch_size = len(messages) - # history lengths - history_lengths = torch.tensor([msg.num_history_ids for msg in messages]) - - # input ids - token_ids = [msg.token_ids for msg in messages] - - input_ids = torch.as_tensor(np.concatenate(token_ids))[None] - - # seqlens - is_decoding = not is_prefill - if not is_decoding: - seq_length = [len(tokens) for tokens in token_ids] - seq_length = torch.tensor(seq_length, dtype=torch.long) - max_q_seqlen = seq_length.max().item() - else: - max_q_seqlen = len(token_ids[0]) - seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) - kv_seqlens = seq_length + history_lengths - max_kv_seqlen = kv_seqlens.max().item() - sum_kv_seqlen = kv_seqlens.sum().item() - - # block offsets - block_offsets = self.scheduler.get_block_tables(messages) - block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype) - - # num_ignored_history - num_ignored_history = torch.tensor([msg.num_ignored_history for msg in messages]) - - # model_metas - model_metas = [msg.model_meta for msg in messages] - - # create model inputs for all required fields - model_inputs = ModelInputs( - input_ids=input_ids, - seq_length=seq_length, - history_lengths=history_lengths, - block_offsets=block_offsets, - is_decoding=is_decoding, - num_ignored_history=num_ignored_history, - max_q_seqlen=max_q_seqlen, - max_kv_seqlen=max_kv_seqlen, - sum_kv_seqlen=sum_kv_seqlen, - model_metas=model_metas, - ) - - # adapters - local_adapter_ids = None - if self.adapter_manager.num_adapters() > 1: - adapter_names = [msg.adapter_name for msg in messages] - local_adapter_ids = self.adapter_manager.get_adapter_ids(adapter_names) - local_adapter_ids = seq_length.new_tensor(local_adapter_ids) - model_inputs.local_adapter_ids = local_adapter_ids - - # cross for mllama - cross_length = torch.tensor([msg.num_cross for msg in messages]) - history_cross_length = torch.tensor([msg.num_history_cross for msg in messages]) - if (cross_length + history_cross_length).max().item() > 0: - model_inputs.cross_length = cross_length - model_inputs.history_cross_length = history_cross_length - - # vision inputs - vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) - model_inputs.vision_inputs = vision_model_inputs - - # ssm - if len(self.cache_config.states_shapes) > 0: - state_offsets = torch.tensor([msg.logical_state for msg in messages]) - model_inputs.state_offsets = state_offsets - - return model_inputs - - def update_running_migration(self, running: SeqList, next_token_ids: np.ndarray, stopped: torch.Tensor, - model_metas: List[Dict[str, Any]]): - """Update scheduler.""" - if model_metas is None: - model_metas = [None] * len(running) - for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): - if msg.status != MessageStatus.MIGRATION_RUNNING: - continue - update_token = token - - # fill token - msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) - if stop: - update_token = _EMPTY_TOKEN - 
msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) - msg.state.finish() - - @record_function('make_infer_outputs') - def _make_infer_outputs( - self, - batched_outputs: BatchedOutputs, - running: SeqList, - is_decoding: bool, - ): - """Make infer output.""" - new_token_timestamp = batched_outputs.new_token_timestamp - logits = batched_outputs.logits - logprobs = batched_outputs.logprobs - - if logprobs is not None: - logprobs.vals = logprobs.vals.tolist() - logprobs.indices = logprobs.indices.tolist() - - seq_length = [seq.num_token_ids for seq in running] - is_run = [seq.status == MessageStatus.RUNNING for seq in running] - self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) - - # generate output - outputs: Dict[int, InferOutput] = dict() - for idx, msg in enumerate(running): - if not is_run[idx]: - continue - token_ids = msg.generated_ids - finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED - if not finish and len(token_ids) == 0: - continue - resp_data = msg.resp.data - if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids): - # no new tokens - continue - session_id = msg.session_id - if msg.resp_cache: - cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist() - else: - cache_block_ids = None - - # logprobs - num_logprobs = msg.sampling_param.num_logprobs - cur_logprobs = None - if logprobs is not None: - cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1]) - # get spec stats info - spec_info = None - if self.specdecode_config is not None and is_decoding and self.engine_config.enable_metrics: - num_draft_tokens = self.specdecode_config.num_speculative_tokens - num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1 - spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens) - req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info) - routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None - if routed_experts is not None and self.engine_config.enable_transfer_obj_ref: - # only serialize for api server - routed_experts = self.executor.serialize(routed_experts) - out = InferOutput(session_id=session_id, - resp=msg.resp, - finish=finish, - token_ids=token_ids, - cache_block_ids=cache_block_ids, - req_metrics=req_metrics, - logprobs=cur_logprobs, - routed_experts=routed_experts) - outputs[session_id] = out - - if msg.return_logits: - outputs[session_id].logits = logits.split(seq_length)[idx] - return outputs - - @record_function('make_forward_inputs') - def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): - """Make forward inputs.""" - - def __need_logits(seqs: SeqList): - """Need logits.""" - if self.specdecode_config is not None: - return True - return any(seq.return_logits for seq in seqs) - - def __need_routed_experts(seqs: SeqList): - """Need routed experts.""" - return any(seq.return_routed_experts for seq in seqs) - - def __need_schedule_again(prefill: bool, scheduler_output): - """Need schedule again.""" - # only reschedule when prefill - if not prefill: - return False - # schedule decoding if no valid prefill reqs. - if len(scheduler_output.running) > 0: - return False - # disable decoding for prefill role - if (self.engine_config.role == EngineRole.Prefill): - return False - # disable decoding if no running reqs. 
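
The removed __need_schedule_again guard (its body continues just below) decides whether a prefill schedule that produced nothing runnable should be retried as a decoding schedule; it leaves Engine together with the rest of _make_forward_inputs. A standalone sketch of that decision, minus the warning log, with plain arguments standing in for the scheduler objects:

from lmdeploy.pytorch.disagg.config import EngineRole


def need_schedule_again(prefill: bool, num_scheduled: int, role: EngineRole, has_ready: bool) -> bool:
    """Sketch of the prefill -> decoding fallback mirrored from the removed helper."""
    if not prefill:
        # only a prefill schedule may be retried as decoding
        return False
    if num_scheduled > 0:
        # the prefill schedule found runnable sequences, keep it
        return False
    if role == EngineRole.Prefill:
        # a prefill-only engine never falls back to decoding
        return False
    # fall back only when there are ready sequences to decode
    return has_ready
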
- if not self.scheduler.has_ready(): - logger.warning('No running sequences for decoding scheduling after prefill scheduling.') - return False - return True - - scheduler = self.scheduler - logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') - - prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) - - if enable_empty and len(scheduler_output.running) == 0: - return None - - if __need_schedule_again(prefill, scheduler_output): - prefill = False - prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) - scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) - - num_loops = self.engine_strategy.get_num_loops(not prefill) - running = scheduler_output.running - swap_in_map = scheduler_output.swap_in_map - swap_out_map = scheduler_output.swap_out_map - - if len(running) == 0: - return None - - # create inputs - inputs = self.create_model_inputs(running, prefill) - sampling_inputs = self.sampling_strategy.make_sampling_inputs(running) - return_logits = __need_logits(running) - return_routed_experts = __need_routed_experts(running) - extra_inputs = self.model_agent_strategy.make_extra_inputs(running) - stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running) - - sync_long_context = inputs.input_ids.numel() > self.cache_config.max_prefill_token_num - - return dict( - running=running, - inputs=inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map, - loop_count=num_loops, - sampling_inputs=sampling_inputs, - stopping_criteria=stopping_criteria, - return_logits=return_logits, - is_dummy=False, - sync_long_context=sync_long_context, - extra_inputs=extra_inputs, - return_routed_experts=return_routed_experts, - ) - - async def _await_forward_event(self, forward_event: asyncio.Event): - """Await forward event.""" - await forward_event.wait() - - @torch.inference_mode() - async def _async_loop_preprocess_message(self, forward_event: asyncio.Event, has_runable_event: RunableEventBase): - """Preprocess msg.""" - while True: - await self._await_forward_event(forward_event) - await self.req_manager.step() - has_runable_event.set() - - async def _async_loop_send_responses(self, que: asyncio.Queue, forward_event: asyncio.Event): - """Send responses.""" - - def __log_resps(outputs: List[InferOutput]): - """Log resps.""" - if logger.level <= logging.DEBUG: - session_ids = [out.session_id for out in outputs] - logger.debug(f'Response sessions: {session_ids}') - elif logger.level <= logging.INFO: - logger.debug(f'Response: num_outputs={len(outputs)}.') - - def __send_resp(out: InferOutput): - """Send response.""" - resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS) - logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None) - self._response(out.resp, - resp_type, - data=dict(token_ids=out.token_ids, - logits=out.logits, - cache_block_ids=out.cache_block_ids, - req_metrics=out.req_metrics, - routed_experts=out.routed_experts, - logprobs=logprobs)) - - def __update_logprobs(step_outputs: List[InferOutput]): - for out in step_outputs: - cur_logprobs = out.logprobs - if cur_logprobs is None: - continue - - if out.resp.data is None: - out.resp.data = dict() - out.resp.data.setdefault('logprobs', []) - - # logprobs to dict - vals = cur_logprobs[0] - indices = cur_logprobs[1] - cur_logprobs = dict(zip(indices, vals)) - logprobs = out.resp.data['logprobs'] - 
logprobs.append(cur_logprobs) - - def __send_resps(step_outputs: List[InferOutput]): - """Send response callback.""" - __log_resps(step_outputs) - __update_logprobs(step_outputs) - - is_done = set() - for out in reversed(step_outputs): - if out.session_id in is_done: - continue - is_done.add(out.session_id) - __send_resp(out) - - while True: - num_outs = que.qsize() - if num_outs > 0: - resps = [] - for _ in range(num_outs): - resps += que.get_nowait().values() - else: - resps = (await que.get()).values() - await self._await_forward_event(forward_event) - __send_resps(resps) - async def p2p_initialize(self, init_request: DistServeInitRequest): return await self.engine_conn.p2p_initialize(init_request) @@ -1002,152 +422,10 @@ def p2p_connect(self, conn_request: DistServeConnectionRequest): async def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): return self.engine_conn.p2p_drop_connect(drop_conn_request) - @torch.inference_mode() - async def _async_loop_migration(self, resp_que: asyncio.Queue, has_runable_event: asyncio.Event): - """Async loop migration.""" - while True: - migration_ready = self.scheduler._schedule_migration() - if not migration_ready and not self.scheduler.has_migration_waiting(): - await self.migration_event.wait() - elif migration_ready: - self.migration_event.clear() - for msg in migration_ready: - migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] - migration_request = msg.migration_request - prefill_block_ids = migration_request.remote_block_ids - decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) - - if not migration_request.is_dummy_prefill: - assert len(prefill_block_ids) == len(decode_block_ids), ( - f'#prefill block ids ({len(prefill_block_ids)}) must equal to ' - f'#decode block ids ({len(decode_block_ids)})' - f'all id length: {len(msg.num_token_ids)}') - migration_execution_requests.append(( - migration_request.remote_engine_id, - list(zip(prefill_block_ids, decode_block_ids)), - )) - migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol, - requests=migration_execution_requests) - logger.info(f'migrating session: {msg.session_id} begin') - await self.executor.migrate(migration_inputs) - logger.info(f'migrating session: {msg.session_id} done') - await self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id, - remote_session_id=migration_request.remote_session_id) - - # generate output - outputs: Dict[int, InferOutput] = dict() - self.scheduler.activate_migration_seqs(migration_ready) - for _, msg in enumerate(migration_ready): - session_id = msg.session_id - msg.resp.type = ResponseType.SUCCESS - token_ids = [msg.migration_request.remote_token_id] - # MUST be a wall-clock time - new_token_timestamp = time.time() - req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events) - out = InferOutput( - session_id=session_id, - resp=msg.resp, - finish=False, - token_ids=np.array(token_ids), - req_metrics=req_metrics, - ) - outputs[session_id] = out - self.update_running_migration([msg], np.array([token_ids]), [False], [None]) - resp_que.put_nowait(outputs) - self.scheduler.deactivate_migration_seqs(migration_ready) - has_runable_event.set() - else: - # release coroutine for decoding - await asyncio.sleep(.5) - - @torch.inference_mode() - async def _async_loop_main( - self, - resp_que: asyncio.Queue, - forward_event: asyncio.Event, - has_runable_event: RunableEventBase, - inputs_maker: InputsMakerBase, - ): - """Main loop of the engine. 
- - Each engine instance would communicate with the engine by queue. - """ - scheduler = self.scheduler - forward_inputs = None - next_running = None - - while True: - if next_running is None: - if not scheduler.has_unfinished(): - forward_event.set() - await has_runable_event.wait() - forward_event.clear() - - scheduler.collect_migration_done() - forward_inputs, next_running = await inputs_maker.send_next_inputs() - if next_running is None: - # TODO (JimyMa): add watermark check event instead of async sleep. - # self.perfill_watermark_event.wait() - logger.warning(f'no next prefill running request, Maybe cache is full, ' - f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, ' - f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}') - forward_event.set() - await asyncio.sleep(0.1) - forward_event.clear() - continue - - forward_event.set() - num_loops = forward_inputs['loop_count'] - is_decoding = forward_inputs['inputs'].is_decoding - running = next_running - next_running = None - scheduler.active_seqs(running) - for idx in range(num_loops): - - # pre-forward before get last token - if idx == num_loops - 1: - scheduler.collect_migration_done() - forward_inputs, next_running = await inputs_maker.prefetch_next_inputs() - # send output - out = await self.executor.get_output_async() - if out is not None: - step_outputs = self._make_infer_outputs(out, running=running, is_decoding=is_decoding) - resp_que.put_nowait(step_outputs) - - # lock forward event - # make sure that prefetch forward would not wait for detokenize - # WARNING: this might have side effect on the performance - if idx == num_loops // 2: - forward_event.clear() - - scheduler.deactive_seqs(running) - has_runable_event.set() - - @staticmethod - def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]): - """Add loop tasks done callback.""" - - def __task_callback(task: asyncio.Task) -> None: - """Raise exception on finish.""" - task_name = task.get_name() - try: - task.result() - except asyncio.CancelledError: - logger.debug(f'Task <{task_name}> cancelled.') - return - except Exception: - logger.exception(f'Task <{task_name}> failed') - finally: - for task in tasks: - if not task.done(): - task.cancel() - - for task in tasks: - task.add_done_callback(__task_callback) - def _loop_finally(self): """Finally process for dist.""" logger.info('Cleanup executor.') + self.migration_event = None self.executor.stop() self.executor.release() @@ -1164,53 +442,25 @@ def wakeup(self, tags: Optional[List[str]] = None): self.executor.wakeup(tags) async def async_loop(self): + engine_loop = None try: + from lmdeploy.pytorch.engine.engine_loop import build_engine_loop + self._loop_main = asyncio.current_task() event_loop = asyncio.get_event_loop() - # forward task - forward_event = CounterEvent() - # migration task - self.migration_event = asyncio.Event() + engine_loop = build_engine_loop(self) + self.migration_event = engine_loop.migration_event + forward_event = engine_loop.forward_event logger.info('Starting executor.') self.executor.start(forward_event) - # preprocess task - logger.info('Starting async task MainLoopPreprocessMessage.') - has_runable_event = build_runable_event(self.scheduler) - loop_msg_proc = event_loop.create_task(self._async_loop_preprocess_message( - forward_event, has_runable_event), - name='MainLoopPreprocessMessage') - - # response task - logger.info('Starting async task MainLoopResponse.') - resp_que = asyncio.Queue() - loop_send_resp = 
event_loop.create_task(self._async_loop_send_responses(resp_que, forward_event), - name='MainLoopResponse') - - loop_main = asyncio.current_task() - loop_tasks: List[asyncio.Task] = [loop_main, loop_msg_proc, loop_send_resp] - - if self.engine_config.role != EngineRole.Hybrid: - logger.info('Starting async task MigrationLoop.') - loop_migration = event_loop.create_task( - self._async_loop_migration(resp_que, has_runable_event=has_runable_event), - name='MainLoopMigration', - ) - loop_tasks.append(loop_migration) - - # binding done callback - self._add_loop_tasks_done_callback(loop_tasks) - self._loop_main = loop_main - - # main loop - logger.info('Starting async task MainLoop.') - inputs_maker = build_inputs_maker(self) - await self._async_loop_main(resp_que=resp_que, - forward_event=forward_event, - has_runable_event=has_runable_event, - inputs_maker=inputs_maker) + # start engine loop + engine_loop.create_tasks(event_loop) + await engine_loop.wait_tasks() + except asyncio.CancelledError: + logger.debug('Engine main loop cancelled.') except Exception: logger.exception('Engine main loop failed.') finally: diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py new file mode 100644 index 0000000000..696c0d6433 --- /dev/null +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -0,0 +1,514 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import logging +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import torch +from torch.profiler import record_function + +from lmdeploy.messages import RequestMetrics +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.disagg.messages import MigrationExecutionBatch +from lmdeploy.pytorch.messages import MessageStatus, UpdateTokenMode +from lmdeploy.utils import get_logger + +from .engine import InferOutput, ResponseType, response_reqs + +if TYPE_CHECKING: + from lmdeploy.pytorch.disagg.conn.engine_conn import EngineP2PConnection + from lmdeploy.pytorch.engine.model_agent import BatchedOutputs + from lmdeploy.pytorch.paging import Scheduler + from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy + + from .engine import Engine, SeqList + from .executor import ExecutorBase + from .input_maker import InputsMakerAsync + from .request import RequestManager + +logger = get_logger('lmdeploy') +_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) + + +class CounterEvent: + + def __init__(self): + self._counter = 0 + self._event = asyncio.Event() + + async def wait(self): + await self._event.wait() + + def is_set(self): + return self._event.is_set() + + def set(self): + if self._counter > 0: + self._counter -= 1 + if self._counter == 0: + self._event.set() + + def clear(self): + if self._counter == 0 and self._event.is_set(): + self._event.clear() + self._counter += 1 + + +class RunableEventAsync: + """Awaitable async runable event.""" + + def __init__(self, scheduler: 'Scheduler'): + self.scheduler = scheduler + self.event = asyncio.Event() + + async def wait(self): + """Wait event.""" + await self.event.wait() + + def set(self): + """Set event.""" + if self.scheduler.has_unfinished(): + self.event.set() + else: + self.event.clear() + + +def build_runable_event(scheduler: 'Scheduler'): + """Build runable event.""" + return RunableEventAsync(scheduler) + + +@dataclass +class EngineLoopConfig: + """Engine loop config. 
+ + This config is added for Dependency Injection + """ + role: EngineRole + num_speculative_tokens: Optional[int] = None + enable_metrics: bool = False + enable_transfer_obj_ref: bool = False + + @staticmethod + def from_engine(engine: 'Engine'): + """Create engine loop config from engine.""" + if engine.specdecode_config is None: + num_speculative_tokens = None + else: + num_speculative_tokens = engine.specdecode_config.num_speculative_tokens + + return EngineLoopConfig( + role=engine.engine_config.role, + num_speculative_tokens=num_speculative_tokens, + enable_metrics=engine.engine_config.enable_metrics, + enable_transfer_obj_ref=engine.engine_config.enable_transfer_obj_ref, + ) + + +class EngineLoop: + """Engine loop manager should be created in an async context.""" + + def __init__(self, + req_manager: 'RequestManager', + scheduler: 'Scheduler', + executor: 'ExecutorBase', + seq_strategy: 'SequenceStrategy', + inputs_maker: 'InputsMakerAsync', + config: EngineLoopConfig, + engine_conn: Optional['EngineP2PConnection'] = None): + self.req_manager = req_manager + self.scheduler = scheduler + self.executor = executor + self.seq_strategy = seq_strategy + self.inputs_maker = inputs_maker + self.config = config + self.engine_conn = engine_conn + + # tasks and control events + self.tasks: Set[asyncio.Task] = set() + self.stop_event = asyncio.Event() + self.resp_queue = asyncio.Queue() + self.forward_event = CounterEvent() + self.migration_event = asyncio.Event() + self.has_runable_event = RunableEventAsync(self.scheduler) + + # check init + if self.config.role != EngineRole.Hybrid: + assert self.engine_conn is not None, 'Engine connection must be provided for non-hybrid engine role.' + + async def preprocess_loop(self): + """Preprocess request.""" + while not self.stop_event.is_set(): + await self.forward_event.wait() + await self.req_manager.step() + self.has_runable_event.set() + + @staticmethod + def _log_resps(outputs: List[InferOutput]): + """Log resps.""" + if logger.level <= logging.DEBUG: + session_ids = [out.session_id for out in outputs] + logger.debug(f'Response sessions: {session_ids}') + elif logger.level <= logging.INFO: + logger.info(f'Response: num_outputs={len(outputs)}.') + + def _send_resp(self, out: InferOutput): + """Send response.""" + resp_type = (ResponseType.FINISH if out.finish else ResponseType.SUCCESS) + logprobs = None if out.resp.data is None else out.resp.data.get('logprobs', None) + response_reqs(self.req_manager, + out.resp, + resp_type, + data=dict(token_ids=out.token_ids, + logits=out.logits, + cache_block_ids=out.cache_block_ids, + req_metrics=out.req_metrics, + routed_experts=out.routed_experts, + logprobs=logprobs)) + + @staticmethod + def _update_logprobs(step_outputs: List[InferOutput]): + for out in step_outputs: + cur_logprobs = out.logprobs + if cur_logprobs is None: + continue + + if out.resp.data is None: + out.resp.data = dict() + out.resp.data.setdefault('logprobs', []) + + # logprobs to dict + vals = cur_logprobs[0] + indices = cur_logprobs[1] + cur_logprobs = dict(zip(indices, vals)) + logprobs = out.resp.data['logprobs'] + logprobs.append(cur_logprobs) + + def _send_resps(self, step_outputs: List[InferOutput]): + """Send response callback.""" + self._log_resps(step_outputs) + self._update_logprobs(step_outputs) + + is_done = set() + for out in reversed(step_outputs): + if out.session_id in is_done: + continue + is_done.add(out.session_id) + self._send_resp(out) + + async def send_response_loop(self): + """Send response to client.""" + que = 
self.resp_queue + while not self.stop_event.is_set(): + num_outs = que.qsize() + if num_outs > 0: + resps = [] + for _ in range(num_outs): + resps += que.get_nowait().values() + else: + resps = (await que.get()).values() + await self.forward_event.wait() + self._send_resps(resps) + + @record_function('make_infer_outputs') + def _make_infer_outputs( + self, + batched_outputs: 'BatchedOutputs', + running: 'SeqList', + is_decoding: bool, + ): + """Make infer output.""" + new_token_timestamp = batched_outputs.new_token_timestamp + logits = batched_outputs.logits + logprobs = batched_outputs.logprobs + + if logprobs is not None: + logprobs.vals = logprobs.vals.tolist() + logprobs.indices = logprobs.indices.tolist() + + seq_length = [seq.num_token_ids for seq in running] + is_run = [seq.status == MessageStatus.RUNNING for seq in running] + self.seq_strategy.update_running(running=running, batched_outputs=batched_outputs, is_decoding=is_decoding) + + # generate output + outputs: Dict[int, InferOutput] = dict() + for idx, msg in enumerate(running): + if not is_run[idx]: + continue + token_ids = msg.generated_ids + finish = msg.status == MessageStatus.STOPPED or msg.status == MessageStatus.TO_BE_MIGRATED + if not finish and len(token_ids) == 0: + continue + resp_data = msg.resp.data + if resp_data is not None and len(resp_data.get('token_ids', [])) == len(token_ids): + # no new tokens + continue + session_id = msg.session_id + if msg.resp_cache: + cache_block_ids = self.scheduler.block_manager.get_block_table(msg).tolist() + else: + cache_block_ids = None + + # logprobs + num_logprobs = msg.sampling_param.num_logprobs + cur_logprobs = None + if logprobs is not None: + cur_logprobs = (logprobs.vals[idx][:num_logprobs + 1], logprobs.indices[idx][:num_logprobs + 1]) + # get spec stats info + spec_info = None + num_draft_tokens = self.config.num_speculative_tokens + if num_draft_tokens is not None and is_decoding and self.config.enable_metrics: + num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1 + spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens) + req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info) + routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None + if routed_experts is not None and self.config.enable_transfer_obj_ref: + # only serialize for api server + routed_experts = self.executor.serialize(routed_experts) + out = InferOutput(session_id=session_id, + resp=msg.resp, + finish=finish, + token_ids=token_ids, + cache_block_ids=cache_block_ids, + req_metrics=req_metrics, + logprobs=cur_logprobs, + routed_experts=routed_experts) + outputs[session_id] = out + + if msg.return_logits: + outputs[session_id].logits = logits.split(seq_length)[idx] + return outputs + + async def _main_loop_try_send_next_inputs(self): + """Try send next inputs.""" + scheduler = self.scheduler + if not scheduler.has_unfinished(): + self.forward_event.set() + await self.has_runable_event.wait() + self.forward_event.clear() + + scheduler.collect_migration_done() + return await self.inputs_maker.send_next_inputs() + + async def _main_loop_get_outputs( + self, + running: 'SeqList', + forward_inputs: Dict[str, Any], + ): + """Get outputs and prefetch.""" + num_loops = forward_inputs['loop_count'] + is_decoding = forward_inputs['inputs'].is_decoding + for idx in range(num_loops): + # pre-forward before get last token + if idx == num_loops - 1: + self.scheduler.collect_migration_done() + 
forward_inputs, next_running = await self.inputs_maker.prefetch_next_inputs() + + # send output + out = await self.executor.get_output_async() + if out is not None: + step_outputs = self._make_infer_outputs(out, running=running, is_decoding=is_decoding) + self.resp_queue.put_nowait(step_outputs) + + # lock forward event + # make sure that prefetch forward would not wait for detokenize + # WARNING: this might have side effect on the performance + if idx == num_loops // 2: + self.forward_event.clear() + return forward_inputs, next_running + + async def main_loop(self): + """Main loop of the engine. + + Each engine instance would communicate with the engine by queue. + """ + forward_event = self.forward_event + has_runable_event = self.has_runable_event + scheduler = self.scheduler + forward_inputs = None + next_running = None + + async def __no_running_warning(): + # TODO (JimyMa): add watermark check event instead of async sleep. + # self.perfill_watermark_event.wait() + logger.warning(f'no next prefill running request, Maybe cache is full, ' + f'free gpu cache blocks: {scheduler.block_manager.get_num_free_gpu_blocks()}, ' + f'total gpu cache blocks: {scheduler.block_manager.num_gpu_blocks}') + forward_event.set() + await asyncio.sleep(0.1) + forward_event.clear() + + while not self.stop_event.is_set(): + if next_running is None: + forward_inputs, next_running = await self._main_loop_try_send_next_inputs() + if next_running is None: + await __no_running_warning() + continue + + forward_event.set() + with scheduler.seqs_activation(next_running): + forward_inputs, next_running = await self._main_loop_get_outputs( + running=next_running, + forward_inputs=forward_inputs, + ) + has_runable_event.set() + + def update_running_migration(self, running: 'SeqList', next_token_ids: np.ndarray, stopped: torch.Tensor, + model_metas: List[Dict[str, Any]]): + """Update scheduler.""" + if model_metas is None: + model_metas = [None] * len(running) + for token, msg, stop, model_meta in zip(next_token_ids, running, stopped, model_metas): + if msg.status != MessageStatus.MIGRATION_RUNNING: + continue + update_token = token + + # fill token + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) + if stop: + update_token = _EMPTY_TOKEN + msg.update_token_ids(update_token, model_meta=model_meta, mode=UpdateTokenMode.PREFILL) + msg.state.finish() + + async def _migration_loop_migrate(self, migration_ready: 'SeqList'): + """Migration loop migrate.""" + for msg in migration_ready: + # skip dummy prefill migration + if msg.migration_request.is_dummy_prefill: + continue + + migration_execution_requests: List[Tuple[int, List[Tuple[int, int]]]] = [] + migration_request = msg.migration_request + prefill_block_ids = migration_request.remote_block_ids + decode_block_ids = list(self.scheduler.block_manager.get_block_table(msg=msg)) + + assert len(prefill_block_ids) == len(decode_block_ids), ( + f'#prefill block ids ({len(prefill_block_ids)}) must equal to ' + f'#decode block ids ({len(decode_block_ids)})' + f'all id length: {len(msg.num_token_ids)}') + migration_execution_requests.append(( + migration_request.remote_engine_id, + list(zip(prefill_block_ids, decode_block_ids)), + )) + migration_inputs = MigrationExecutionBatch(protocol=migration_request.protocol, + requests=migration_execution_requests) + logger.info(f'migrating session: {msg.session_id} begin') + await self.executor.migrate(migration_inputs) + logger.info(f'migrating session: {msg.session_id} done') + await 
self.engine_conn.zmq_send(remote_engine_id=migration_request.remote_engine_id, + remote_session_id=migration_request.remote_session_id) + + async def _migration_loop_get_outputs(self, migration_ready: 'SeqList'): + """Migration loop get outputs.""" + outputs: Dict[int, InferOutput] = dict() + for _, msg in enumerate(migration_ready): + session_id = msg.session_id + msg.resp.type = ResponseType.SUCCESS + token_ids = [msg.migration_request.remote_token_id] + # MUST be a wall-clock time + new_token_timestamp = time.time() + req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events) + out = InferOutput( + session_id=session_id, + resp=msg.resp, + finish=False, + token_ids=np.array(token_ids), + req_metrics=req_metrics, + ) + outputs[session_id] = out + self.update_running_migration([msg], np.array([token_ids]), [False], [None]) + self.resp_queue.put_nowait(outputs) + + async def _migration_loop_process_ready(self, migration_ready: 'SeqList'): + """Process migration ready.""" + await self._migration_loop_migrate(migration_ready) + + # generate output + with self.scheduler.seqs_migration_activation(migration_ready): + await self._migration_loop_get_outputs(migration_ready) + self.has_runable_event.set() + + async def migration_loop(self): + """Async loop migration.""" + while not self.stop_event.is_set(): + migration_ready = self.scheduler._schedule_migration() + if not migration_ready and not self.scheduler.has_migration_waiting(): + await self.migration_event.wait() + elif migration_ready: + self.migration_event.clear() + await self._migration_loop_process_ready(migration_ready) + else: + # release coroutine for decoding + await asyncio.sleep(.5) + + def _add_loop_tasks_done_callback(self): + """Add loop tasks done callback.""" + + def __task_callback(task: asyncio.Task) -> None: + """Raise exception on finish.""" + task_name = task.get_name() + try: + task.result() + except asyncio.CancelledError: + logger.debug(f'Task <{task_name}> cancelled.') + except Exception: + logger.exception(f'Task <{task_name}> failed') + finally: + self.stop_event.set() + self.cancel() + + for task in self.tasks: + task.add_done_callback(__task_callback) + + def create_tasks(self, event_loop: asyncio.AbstractEventLoop): + """Create async tasks.""" + logger.info('Starting async task MainLoopPreprocessMessage.') + self.tasks.add(event_loop.create_task(self.preprocess_loop(), name='MainLoopPreprocessMessage')) + logger.info('Starting async task MainLoopResponse.') + self.tasks.add(event_loop.create_task(self.send_response_loop(), name='MainLoopSendResponse')) + logger.info('Starting async task MainLoop.') + self.tasks.add(event_loop.create_task(self.main_loop(), name='MainLoopMain')) + if self.config.role != EngineRole.Hybrid: + logger.info('Starting async task MigrationLoop.') + self.tasks.add(event_loop.create_task(self.migration_loop(), name='MainLoopMigration')) + + self._add_loop_tasks_done_callback() + + async def wait_tasks(self, timeout: Optional[float] = None): + """Wait for all tasks to finish.""" + if len(self.tasks) == 0: + return + if timeout is not None: + await asyncio.wait(asyncio.gather(*self.tasks), timeout=timeout) + else: + await asyncio.gather(*self.tasks) + self.stop() + + def stop(self): + """Stop all loops.""" + self.stop_event.set() + + def cancel(self): + """Cancel all loops.""" + for task in self.tasks: + if not (task.done() or task.cancelled()): + task.cancel() + + +def build_engine_loop(engine: 'Engine'): + """Build engine loop.""" + from .input_maker import build_inputs_maker + 
+ config = EngineLoopConfig.from_engine(engine) + inputs_maker = build_inputs_maker(engine) + return EngineLoop( + req_manager=engine.req_manager, + scheduler=engine.scheduler, + executor=engine.executor, + seq_strategy=engine.seq_strategy, + inputs_maker=inputs_maker, + config=config, + engine_conn=engine.engine_conn, + ) diff --git a/lmdeploy/pytorch/engine/input_maker.py b/lmdeploy/pytorch/engine/input_maker.py new file mode 100644 index 0000000000..e0167c3825 --- /dev/null +++ b/lmdeploy/pytorch/engine/input_maker.py @@ -0,0 +1,415 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np +import torch +from torch.profiler import record_function + +from lmdeploy.pytorch.disagg.config import EngineRole +from lmdeploy.pytorch.model_inputs import ModelInputs, VisionModelInputs +from lmdeploy.utils import get_logger + +if TYPE_CHECKING: + from lmdeploy.pytorch.adapter.adapter import AdapterManager + from lmdeploy.pytorch.paging import Scheduler + from lmdeploy.pytorch.strategies.base.engine import EngineStrategy + from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy + + from .engine import Engine, SeqList + from .executor import ExecutorBase + +logger = get_logger('lmdeploy') + + +def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): + """Tensorlize block_offsets.""" + # copy on numpy is faster than torch.nn.utils.rnn.pad_sequence + batch_size = len(block_offsets) + max_len = max([len(off) for off in block_offsets]) + out = np.zeros((batch_size, max_len), dtype=block_offsets[0].dtype) + + for idx, off in enumerate(block_offsets): + off_len = len(off) + out[idx, :off_len] = off + return torch.as_tensor(out, dtype=dtype) + + +@dataclass +class InputMakerConfig: + """Input maker config. + + This config is added for Dependency Injection + """ + max_batches: int + max_prefill_token_num: int + role: EngineRole + is_ssm: bool = False + dp: int = 1 + spec_decoding: bool = False + + @staticmethod + def from_engine(engine: 'Engine'): + cache_config = engine.cache_config + return InputMakerConfig( + spec_decoding=engine.specdecode_config is not None, + max_batches=cache_config.max_batches, + max_prefill_token_num=cache_config.max_prefill_token_num, + role=cache_config.role, + is_ssm=len(cache_config.states_shapes) > 0, + dp=engine.dist_config.dp, + ) + + +class InputsMakerAsync: + + def __init__( + self, + executor: 'ExecutorBase', + scheduler: 'Scheduler', + adapter_manager: 'AdapterManager', + engine_strategy: 'EngineStrategy', + sampling_strategy: 'SamplingStrategy', + model_agent_strategy: 'ModelAgentStrategy', + config: InputMakerConfig, + ): + self.executor = executor + self.scheduler = scheduler + self.adapter_manager = adapter_manager + self.config = config + self.spec_decoding = config.spec_decoding + + # strategies + self.engine_strategy = engine_strategy + self.sampling_strategy = sampling_strategy + self.model_agent_strategy = model_agent_strategy + + self._init_do_prefill(config) + + # record for next forward. 
+ self.next_is_prefill = True + self.forward_inputs = None + + def _init_do_prefill(self, config: InputMakerConfig): + if config.role == EngineRole.Prefill: + self.do_prefill = self.do_prefill_pnode + elif config.dp == 1: + self.do_prefill = self.do_prefill_default + else: + self.do_prefill = self.do_prefill_dp + + def _create_vision_model_inputs(self, messages: 'SeqList', model_inputs: ModelInputs): + """Create vision model inputs.""" + batch_size = len(messages) + + def __get_vlm_embeddings(): + """Get vlm input embeddings and indexings.""" + max_q_seq_length = model_inputs.seq_length.max().item() + input_embeddings = [[ + emb.embeddings if isinstance(emb.embeddings, torch.Tensor) else torch.as_tensor(emb.embeddings) + for emb in msg.input_embeddings + ] for msg in messages] + input_embedding_ranges = [ + torch.tensor([[emb.start, emb.end] for emb in msg.input_embeddings]) for msg in messages + ] + input_embedding_indexing = torch.zeros((batch_size, max_q_seq_length), dtype=torch.bool) + for msg_id, msg in enumerate(messages): + num_history_ids = msg.num_history_ids + for emb in msg.input_embeddings: + # make slice index relative to embeddings + emb_start = emb.start - num_history_ids + emb_end = emb.end - num_history_ids + input_embedding_indexing[msg_id][emb_start:emb_end] = True + return (input_embeddings, input_embedding_indexing, input_embedding_ranges) + + def __has_values(input_multimodals): + for input_mm in input_multimodals: + for val in input_mm.values(): + if len(val) > 0: + return True + return False + + has_embedding = any([len(msg.history_embeddings) > 0 for msg in messages]) + if has_embedding: + has_embedding = any([len(msg.input_embeddings) > 0 for msg in messages]) + + has_multimodal = any([not msg.history_multimodals.empty() for msg in messages]) + input_multimodals = None + if has_multimodal: + input_multimodals = [msg.get_input_multimodals() for msg in messages] + has_multimodal = __has_values(input_multimodals) + if not has_multimodal: + # no multimodal inputs + input_multimodals = None + + if not has_embedding and not has_multimodal: + # no vision inputs + return None + + if has_embedding: + # for inputs with embeddings + (input_embeddings, input_embedding_indexing, input_embedding_ranges) = __get_vlm_embeddings() + else: + input_embeddings = None + input_embedding_indexing = None + input_embedding_ranges = None + + history_lengths = model_inputs.history_lengths + vision_embedding_inputs = VisionModelInputs(history_lengths=history_lengths, + input_embeddings=input_embeddings, + input_embedding_indexing=input_embedding_indexing, + input_embedding_ranges=input_embedding_ranges, + input_multimodals=input_multimodals) + return vision_embedding_inputs + + @property + def torch_int_dtype(self): + """Return int32 for cuda, int64 for others.""" + if self.executor.device_type == 'cuda': + return torch.int32 + return torch.int64 + + @torch.inference_mode() + @record_function('create_model_inputs') + def create_model_inputs(self, messages: 'SeqList', is_prefill: bool): + """Create model inputs from messages. + + Args: + messages (SeqList): The input messages. 
+ """ + batch_size = len(messages) + # history lengths + history_lengths = torch.tensor([msg.num_history_ids for msg in messages]) + + # input ids + token_ids = [msg.token_ids for msg in messages] + + input_ids = torch.as_tensor(np.concatenate(token_ids))[None] + + # seqlens + is_decoding = not is_prefill + if not is_decoding: + seq_length = [len(tokens) for tokens in token_ids] + seq_length = torch.tensor(seq_length, dtype=torch.long) + max_q_seqlen = seq_length.max().item() + else: + max_q_seqlen = len(token_ids[0]) + seq_length = torch.full((batch_size, ), max_q_seqlen, dtype=torch.long) + kv_seqlens = seq_length + history_lengths + max_kv_seqlen = kv_seqlens.max().item() + sum_kv_seqlen = kv_seqlens.sum().item() + + # block offsets + block_offsets = self.scheduler.get_block_tables(messages) + block_offsets = _tensorlize_block_offsets(block_offsets, dtype=self.torch_int_dtype) + + # num_ignored_history + num_ignored_history = torch.tensor([msg.num_ignored_history for msg in messages]) + + # model_metas + model_metas = [msg.model_meta for msg in messages] + + # create model inputs for all required fields + model_inputs = ModelInputs( + input_ids=input_ids, + seq_length=seq_length, + history_lengths=history_lengths, + block_offsets=block_offsets, + is_decoding=is_decoding, + num_ignored_history=num_ignored_history, + max_q_seqlen=max_q_seqlen, + max_kv_seqlen=max_kv_seqlen, + sum_kv_seqlen=sum_kv_seqlen, + model_metas=model_metas, + ) + + # adapters + local_adapter_ids = None + if self.adapter_manager.num_adapters() > 1: + adapter_names = [msg.adapter_name for msg in messages] + local_adapter_ids = self.adapter_manager.get_adapter_ids(adapter_names) + local_adapter_ids = seq_length.new_tensor(local_adapter_ids) + model_inputs.local_adapter_ids = local_adapter_ids + + # cross for mllama + cross_length = torch.tensor([msg.num_cross for msg in messages]) + history_cross_length = torch.tensor([msg.num_history_cross for msg in messages]) + if (cross_length + history_cross_length).max().item() > 0: + model_inputs.cross_length = cross_length + model_inputs.history_cross_length = history_cross_length + + # vision inputs + vision_model_inputs = self._create_vision_model_inputs(messages, model_inputs) + model_inputs.vision_inputs = vision_model_inputs + + # ssm + if self.config.is_ssm: + state_offsets = torch.tensor([msg.logical_state for msg in messages]) + model_inputs.state_offsets = state_offsets + + return model_inputs + + @torch.inference_mode() + @record_function('make_forward_inputs') + def _make_forward_inputs(self, prefill: bool, enable_empty: bool = False): + """Make forward inputs for ModelAgent._async_step_background()""" + + def __need_logits(seqs: 'SeqList'): + """Need logits.""" + if self.spec_decoding: + return True + return any(seq.return_logits for seq in seqs) + + def __need_routed_experts(seqs: 'SeqList'): + """Need routed experts.""" + return any(seq.return_routed_experts for seq in seqs) + + def __need_schedule_again(prefill: bool, scheduler_output): + """Need schedule again.""" + # only reschedule when prefill + if not prefill: + return False + # schedule decoding if no valid prefill reqs. + if len(scheduler_output.running) > 0: + return False + # disable decoding for prefill role + if (self.config.role == EngineRole.Prefill): + return False + # disable decoding if no running reqs. 
+ if not self.scheduler.has_ready(): + logger.warning('No running sequences for decoding scheduling after prefill scheduling.') + return False + return True + + scheduler = self.scheduler + logger.debug(f'Make forward inputs with prefill={prefill}, enable_empty={enable_empty}') + + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) + + if enable_empty and len(scheduler_output.running) == 0: + return None + + if __need_schedule_again(prefill, scheduler_output): + prefill = False + prealloc_size = self.engine_strategy.get_prealloc_size(not prefill) + scheduler_output = scheduler.schedule(is_prefill=prefill, prealloc_size=prealloc_size) + + num_loops = self.engine_strategy.get_num_loops(not prefill) + running = scheduler_output.running + swap_in_map = scheduler_output.swap_in_map + swap_out_map = scheduler_output.swap_out_map + + if len(running) == 0: + return None + + # create inputs + inputs = self.create_model_inputs(running, prefill) + sampling_inputs = self.sampling_strategy.make_sampling_inputs(running) + return_logits = __need_logits(running) + return_routed_experts = __need_routed_experts(running) + extra_inputs = self.model_agent_strategy.make_extra_inputs(running) + stopping_criteria = self.model_agent_strategy.make_stopping_criteria(running) + + sync_long_context = inputs.input_ids.numel() > self.config.max_prefill_token_num + + return dict( + running=running, + inputs=inputs, + swap_in_map=swap_in_map, + swap_out_map=swap_out_map, + loop_count=num_loops, + sampling_inputs=sampling_inputs, + stopping_criteria=stopping_criteria, + return_logits=return_logits, + is_dummy=False, + sync_long_context=sync_long_context, + extra_inputs=extra_inputs, + return_routed_experts=return_routed_experts, + ) + + def do_prefill_pnode(self): + return True + + def do_prefill_dp(self): + scheduler = self.scheduler + + if self.next_is_prefill: + ret = scheduler.has_waiting() + else: + ret = not scheduler.has_ready() + return ret + + def do_prefill_default(self): + # decoding if no waiting + scheduler = self.scheduler + if not scheduler.has_waiting(): + return False + num_ready = scheduler.num_ready() + num_waiting = scheduler.num_waiting() + max_batches = self.config.max_batches + # prefill if too much waiting + permitted_waiting = 4 if (self.config.role != EngineRole.Prefill) else 1 + if num_waiting >= permitted_waiting: + return True + # prefill if no enough running + if num_ready < max_batches * 0.5: + return True + # decoding + return False + + async def _send_next_inputs_impl(self, prefill: bool = None, enable_empty: bool = False): + forward_inputs = self._make_forward_inputs(prefill, enable_empty) + if forward_inputs is None: + return None, None + next_running = forward_inputs.pop('running') + inputs = forward_inputs['inputs'] + logger.debug(f'Sending forward inputs: {inputs.log_info()}') + if logger.level <= logging.DEBUG: + session_ids = [seq.session_id for seq in next_running] + logger.debug(f'Forward session_ids: {session_ids}') + self.next_is_prefill = inputs.is_decoding + await self.executor.forward_async(forward_inputs) + self.forward_inputs = forward_inputs + return forward_inputs, next_running + + async def send_next_inputs(self): + prefill = self.do_prefill() + return await self._send_next_inputs_impl(prefill) + + async def prefetch_next_inputs(self): + enable = False + scheduler = self.scheduler + prefill = self.do_prefill() + if prefill: + enable = True + else: + num_ready = 
scheduler.num_ready() + is_decoding = self.forward_inputs['inputs'].is_decoding + running_threshold = (self.config.max_batches // 4) if is_decoding or self.spec_decoding else 0 + + if num_ready > running_threshold: + enable = True + + if enable: + # send next forward + logger.debug('Prefetching next forward inputs.') + return await self._send_next_inputs_impl(prefill, True) + else: + return None, None + + +def build_inputs_maker(engine: 'Engine'): + """Build inputs makers.""" + config = InputMakerConfig.from_engine(engine) + return InputsMakerAsync( + executor=engine.executor, + scheduler=engine.scheduler, + adapter_manager=engine.adapter_manager, + engine_strategy=engine.engine_strategy, + sampling_strategy=engine.sampling_strategy, + model_agent_strategy=engine.model_agent_strategy, + config=config, + ) diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 268d9556dd..e58dec587e 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -191,7 +191,7 @@ def event_loop(self): else: return self._loop_task.get_loop() - def start_loop(self, loop: asyncio.Task): + def set_main_loop(self, loop: asyncio.Task): """Start main loop.""" self._loop_coro = loop diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index bbf7ff903a..f52d9b690b 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -2,6 +2,7 @@ # modify from: https://github.com/vllm-project/vllm from collections import OrderedDict +from contextlib import contextmanager from dataclasses import dataclass from typing import Dict, List @@ -317,24 +318,42 @@ def get_block_tables(self, seqs: SeqList): """Get block table of the sequences.""" return [self.block_manager.get_block_table(seq) for seq in seqs] - def active_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY): + def activate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.READY): """Lock running sequence.""" for seq in running: if seq.status == filter_status: seq.state.activate() - def deactive_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING): + def deactivate_seqs(self, running: SeqList, filter_status: MessageStatus = MessageStatus.RUNNING): for seq in running: if seq.status == filter_status: seq.state.deactivate() + @contextmanager + def seqs_activation(self, running: SeqList): + """Context manager to activate and deactivate sequences.""" + self.activate_seqs(running, MessageStatus.READY) + try: + yield running + finally: + self.deactivate_seqs(running, MessageStatus.RUNNING) + def activate_migration_seqs(self, running: SeqList): """Lock running sequence.""" - return self.active_seqs(running, filter_status=MessageStatus.MIGRATION_READY) + return self.activate_seqs(running, filter_status=MessageStatus.MIGRATION_READY) def deactivate_migration_seqs(self, running: SeqList): """Unlock running migration.""" - return self.deactive_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING) + return self.deactivate_seqs(running, filter_status=MessageStatus.MIGRATION_RUNNING) + + @contextmanager + def seqs_migration_activation(self, running: SeqList): + """Context manager to activate and deactivate sequences.""" + self.activate_migration_seqs(running) + try: + yield running + finally: + self.deactivate_migration_seqs(running) def collect_migration_done(self): for seq in self.migration_done: @@ -343,7 +362,7 @@ def collect_migration_done(self): @property def 
schedule_metrics(self): return ScheduleMetrics( - active_seqs=self.num_running(), + activate_seqs=self.num_running(), waiting_seqs=self.num_waiting() + self.num_ready(), total_blocks=self.block_manager.num_gpu_blocks, free_blocks=self.block_manager.get_num_free_gpu_blocks(), From 981be9cfb6efd4c175e72bdc1d7edfb504e6edf2 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 3 Dec 2025 14:57:54 +0800 Subject: [PATCH 08/12] fix close --- lmdeploy/pytorch/engine/engine.py | 5 +-- lmdeploy/pytorch/engine/engine_loop.py | 38 ++++++++++++++------ lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py | 1 - lmdeploy/pytorch/engine/request.py | 2 +- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 7c2c30d473..fbb38683e6 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -460,8 +460,9 @@ async def async_loop(self): engine_loop.create_tasks(event_loop) await engine_loop.wait_tasks() except asyncio.CancelledError: - logger.debug('Engine main loop cancelled.') - except Exception: + logger.info('Engine main loop cancelled.') + raise + except BaseException: logger.exception('Engine main loop failed.') finally: self._loop_finally() diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 696c0d6433..fa121dd22d 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -452,11 +452,11 @@ def __task_callback(task: asyncio.Task) -> None: try: task.result() except asyncio.CancelledError: - logger.debug(f'Task <{task_name}> cancelled.') - except Exception: + logger.info(f'Task <{task_name}> cancelled.') + except BaseException: logger.exception(f'Task <{task_name}> failed') finally: - self.stop_event.set() + self.stop() self.cancel() for task in self.tasks: @@ -476,15 +476,31 @@ def create_tasks(self, event_loop: asyncio.AbstractEventLoop): self._add_loop_tasks_done_callback() - async def wait_tasks(self, timeout: Optional[float] = None): + async def wait_tasks(self): """Wait for all tasks to finish.""" - if len(self.tasks) == 0: + if not self.tasks: return - if timeout is not None: - await asyncio.wait(asyncio.gather(*self.tasks), timeout=timeout) - else: - await asyncio.gather(*self.tasks) - self.stop() + + try: + done, pending = await asyncio.wait(self.tasks, return_when=asyncio.FIRST_EXCEPTION) + + # cancel all pending tasks + for task in pending: + task.cancel() + + for task in done: + try: + task.result() + except asyncio.CancelledError: + logger.debug('Task cancelled.') + except asyncio.CancelledError: + logger.info('Engine loop wait tasks cancelled.') + raise + except BaseException: + logger.exception('Engine loop wait tasks failed.') + finally: + self.stop() + self.cancel() def stop(self): """Stop all loops.""" @@ -493,7 +509,7 @@ def stop(self): def cancel(self): """Cancel all loops.""" for task in self.tasks: - if not (task.done() or task.cancelled()): + if not task.done(): task.cancel() diff --git a/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py b/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py index 0d85ff9146..f3d40aa7c6 100644 --- a/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py +++ b/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py @@ -23,7 +23,6 @@ def _task_callback(task: asyncio.Task) -> None: task.result() except asyncio.CancelledError: logger.debug(f'Task <{task_name}> cancelled.') - return except Exception: logger.exception(f'Task <{task_name}> failed') finally: diff --git 
a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index e58dec587e..95f52e9a97 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -179,7 +179,7 @@ def create_loop_task(self): event_loop = asyncio.get_event_loop() assert self._loop_coro is not None, ('Please set loop task with manager.start_loop') loop_unshielded = event_loop.create_task(self._loop_coro(), name='EngineMainLoop') - self._loop_task = asyncio.shield(loop_unshielded) + self._loop_task = loop_unshielded self.requests = asyncio.Queue() return self._loop_task From bde955d80cad23b2acbbfa2cdb76db001e431e69 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 3 Dec 2025 16:25:32 +0800 Subject: [PATCH 09/12] fix ut --- tests/pytorch/engine/test_request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/engine/test_request.py b/tests/pytorch/engine/test_request.py index 811155f91f..1c3b0a175d 100644 --- a/tests/pytorch/engine/test_request.py +++ b/tests/pytorch/engine/test_request.py @@ -42,7 +42,7 @@ async def __dummy_loop(): return sender = manager.build_sender() - manager.start_loop(__dummy_loop) + manager.set_main_loop(__dummy_loop) # test not bind resp = sender.send_async(RequestType.STOP_ENGINE, None) From 099bfb901aa9e4be81890fe590166f0776463f89 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 3 Dec 2025 18:14:33 +0800 Subject: [PATCH 10/12] update end session --- lmdeploy/pytorch/engine/engine.py | 14 +++++++------ lmdeploy/pytorch/engine/engine_loop.py | 21 +++++++------------ .../{input_maker.py => inputs_maker.py} | 10 ++++----- lmdeploy/pytorch/engine/request.py | 4 ++-- lmdeploy/pytorch/messages.py | 2 ++ lmdeploy/pytorch/paging/scheduler.py | 3 +++ tests/pytorch/engine/test_request.py | 2 +- 7 files changed, 29 insertions(+), 27 deletions(-) rename lmdeploy/pytorch/engine/{input_maker.py => inputs_maker.py} (98%) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index fbb38683e6..ef208849a3 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -54,10 +54,10 @@ class InferOutput: routed_experts: torch.Tensor = None -def _build_seq_meta(cache_config: CacheConfig, strategy: Any): +def _build_seq_meta(cache_config: CacheConfig, seq_strategy: Any, sampling_strategy: Any): from lmdeploy.pytorch.messages import SequenceMeta - seq_meta = SequenceMeta(cache_config.block_size, strategy=strategy) + seq_meta = SequenceMeta(cache_config.block_size, strategy=seq_strategy, sampling_strategy=sampling_strategy) return seq_meta @@ -156,7 +156,9 @@ def __init__( self.input_processor = self.executor.get_input_processor() cache_config = self.executor.cache_config self.adapter_manager = self._build_adapter_manager(adapters) - self.seq_meta = _build_seq_meta(cache_config, strategy=self.seq_strategy) + self.seq_meta = _build_seq_meta(cache_config, + seq_strategy=self.seq_strategy, + sampling_strategy=self.sampling_strategy) self.scheduler = Scheduler(scheduler_config, cache_config, seq_meta=self.seq_meta) # engine args @@ -174,7 +176,7 @@ def __init__( self.req_manager = self._bind_request_manager() # create main thread - self.req_manager.set_main_loop(self.async_loop) + self.req_manager.set_main_loop_func(self.async_loop) self._loop_main = None # for PD Disaggregation @@ -448,11 +450,12 @@ async def async_loop(self): self._loop_main = asyncio.current_task() event_loop = asyncio.get_event_loop() - # migration task + # create engine loop engine_loop = build_engine_loop(self) 
self.migration_event = engine_loop.migration_event forward_event = engine_loop.forward_event + # start executor logger.info('Starting executor.') self.executor.start(forward_event) @@ -499,7 +502,6 @@ def start_loop(self): def end_session(self, session_id: int): """End session.""" if session_id in self.scheduler.sessions: - self.sampling_strategy.on_session_end(session_id) self.scheduler.end_session(session_id) return True return False diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index fa121dd22d..5b9aeb6734 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -25,34 +25,28 @@ from .engine import Engine, SeqList from .executor import ExecutorBase - from .input_maker import InputsMakerAsync + from .inputs_maker import InputsMakerAsync from .request import RequestManager logger = get_logger('lmdeploy') _EMPTY_TOKEN = np.empty((0, ), dtype=np.int64) -class CounterEvent: +class CounterEvent(asyncio.Event): def __init__(self): + super().__init__() self._counter = 0 - self._event = asyncio.Event() - - async def wait(self): - await self._event.wait() - - def is_set(self): - return self._event.is_set() def set(self): if self._counter > 0: self._counter -= 1 if self._counter == 0: - self._event.set() + super().set() def clear(self): - if self._counter == 0 and self._event.is_set(): - self._event.clear() + if self._counter == 0 and super().is_set(): + super().clear() self._counter += 1 @@ -511,11 +505,12 @@ def cancel(self): for task in self.tasks: if not task.done(): task.cancel() + self.tasks.clear() def build_engine_loop(engine: 'Engine'): """Build engine loop.""" - from .input_maker import build_inputs_maker + from .inputs_maker import build_inputs_maker config = EngineLoopConfig.from_engine(engine) inputs_maker = build_inputs_maker(engine) diff --git a/lmdeploy/pytorch/engine/input_maker.py b/lmdeploy/pytorch/engine/inputs_maker.py similarity index 98% rename from lmdeploy/pytorch/engine/input_maker.py rename to lmdeploy/pytorch/engine/inputs_maker.py index e0167c3825..9a67179191 100644 --- a/lmdeploy/pytorch/engine/input_maker.py +++ b/lmdeploy/pytorch/engine/inputs_maker.py @@ -38,7 +38,7 @@ def _tensorlize_block_offsets(block_offsets, dtype=torch.int32): @dataclass -class InputMakerConfig: +class InputsMakerConfig: """Input maker config. 
This config is added for Dependency Injection @@ -53,7 +53,7 @@ class InputMakerConfig: @staticmethod def from_engine(engine: 'Engine'): cache_config = engine.cache_config - return InputMakerConfig( + return InputsMakerConfig( spec_decoding=engine.specdecode_config is not None, max_batches=cache_config.max_batches, max_prefill_token_num=cache_config.max_prefill_token_num, @@ -73,7 +73,7 @@ def __init__( engine_strategy: 'EngineStrategy', sampling_strategy: 'SamplingStrategy', model_agent_strategy: 'ModelAgentStrategy', - config: InputMakerConfig, + config: InputsMakerConfig, ): self.executor = executor self.scheduler = scheduler @@ -92,7 +92,7 @@ def __init__( self.next_is_prefill = True self.forward_inputs = None - def _init_do_prefill(self, config: InputMakerConfig): + def _init_do_prefill(self, config: InputsMakerConfig): if config.role == EngineRole.Prefill: self.do_prefill = self.do_prefill_pnode elif config.dp == 1: @@ -403,7 +403,7 @@ async def prefetch_next_inputs(self): def build_inputs_maker(engine: 'Engine'): """Build inputs makers.""" - config = InputMakerConfig.from_engine(engine) + config = InputsMakerConfig.from_engine(engine) return InputsMakerAsync( executor=engine.executor, scheduler=engine.scheduler, diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py index 95f52e9a97..cb1cc3e8a5 100644 --- a/lmdeploy/pytorch/engine/request.py +++ b/lmdeploy/pytorch/engine/request.py @@ -3,7 +3,7 @@ import enum import logging from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable, Dict, List +from typing import Any, Awaitable, Callable, Coroutine, Dict, List from lmdeploy.messages import RequestMetrics, ResponseType from lmdeploy.utils import get_logger @@ -191,7 +191,7 @@ def event_loop(self): else: return self._loop_task.get_loop() - def set_main_loop(self, loop: asyncio.Task): + def set_main_loop_func(self, loop: Callable[[Coroutine], asyncio.Task]): """Start main loop.""" self._loop_coro = loop diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py index 699988f9b6..35c75647bd 100644 --- a/lmdeploy/pytorch/messages.py +++ b/lmdeploy/pytorch/messages.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from lmdeploy.pytorch.paging.scheduler import Scheduler from lmdeploy.pytorch.paging.seq_states.states import StateBase + from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy from lmdeploy.pytorch.strategies.base.sequence import SequenceStrategy logger = get_logger('lmdeploy') @@ -173,6 +174,7 @@ class SequenceMeta: """Meta data shared by all sequence.""" block_size: int strategy: 'SequenceStrategy' = None + sampling_strategy: 'SamplingStrategy' = None class SequenceManager: diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index f52d9b690b..acaf70fcb8 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -63,6 +63,7 @@ def __init__( self.eviction_helper = build_eviction_helper(self, self.scheduler_config.eviction_type) seq_meta = seq_meta or SequenceMeta(self.cache_config.block_size) + self.seq_meta = seq_meta self.seq_manager = SequenceManager(seq_meta) @staticmethod @@ -302,6 +303,8 @@ def end_session(self, session_id: int): Args: session_id (int): The session id. 
""" + if self.seq_meta.sampling_strategy is not None: + self.seq_meta.sampling_strategy.on_session_end(session_id) session = self.sessions[session_id] seqs = list(session.sequences.values()) for seq in seqs: diff --git a/tests/pytorch/engine/test_request.py b/tests/pytorch/engine/test_request.py index 1c3b0a175d..41771b1a3c 100644 --- a/tests/pytorch/engine/test_request.py +++ b/tests/pytorch/engine/test_request.py @@ -42,7 +42,7 @@ async def __dummy_loop(): return sender = manager.build_sender() - manager.set_main_loop(__dummy_loop) + manager.set_main_loop_func(__dummy_loop) # test not bind resp = sender.send_async(RequestType.STOP_ENGINE, None) From d4268eb0838abc00d542db0248201e1c7963bcf0 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 10 Dec 2025 12:32:36 +0800 Subject: [PATCH 11/12] fix profiler --- lmdeploy/pytorch/engine/model_agent.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 49e7d671f0..677cb1d486 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -124,10 +124,7 @@ def __init__(self, dist_ctx: DistContext, stream: torch.Stream): self.dp = dist_ctx.dist_config.dp self.stream = stream self.profiler = None - if self.dp == 1: - self.name = f'rank[{self.rank}]' - else: - self.name = f'dp_rank[{self.dp_rank}]' + self.name = f'rank[{self.rank}]' self.delay = envs.torch_profile_delay self.duration = envs.torch_profile_duration @@ -166,7 +163,7 @@ def dump(self): try: self.profiler.stop() - rank = self.rank if self.dp == 1 else self.dp_rank + rank = self.rank dump_path = f'{self.prefix}{rank}.json' self.profiler.export_chrome_trace(dump_path) logger.warning(f'Profiler {self.name} dump to {dump_path}.') From a370912532a5001e20b9382479cb25e8eaddc7e7 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Thu, 11 Dec 2025 11:50:03 +0800 Subject: [PATCH 12/12] typo fix --- lmdeploy/pytorch/paging/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py index 510ad7c7dd..ce2c14a8ae 100644 --- a/lmdeploy/pytorch/paging/scheduler.py +++ b/lmdeploy/pytorch/paging/scheduler.py @@ -366,7 +366,7 @@ def collect_migration_done(self): @property def schedule_metrics(self): return ScheduleMetrics( - activate_seqs=self.num_running(), + active_seqs=self.num_running(), waiting_seqs=self.num_waiting() + self.num_ready(), total_blocks=self.block_manager.num_gpu_blocks, free_blocks=self.block_manager.get_num_free_gpu_blocks(),