
Commit 1426ea4

Return routed experts when request canceled (#4197)
1 parent d38f032 commit 1426ea4


5 files changed (+28, -23 lines)


lmdeploy/cli/serve.py

Lines changed: 1 addition & 0 deletions
@@ -237,6 +237,7 @@ def api_server(args):
                 dllm_denoising_steps=args.dllm_denoising_steps,
                 dllm_confidence_threshold=args.dllm_confidence_threshold,
                 enable_return_routed_experts=args.enable_return_routed_experts,
+                distributed_executor_backend=args.distributed_executor_backend,
             )
         else:
             from lmdeploy.messages import TurbomindEngineConfig
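For context, a rough sketch (not part of this commit) of how the three configuration flags touched by this change interact; the field names all appear in the diff, but treat the exact call below as an illustration rather than the project's documented API. Routed experts are only wrapped in a Ray ObjectRef when object-ref transfer is enabled and the executor backend is 'ray', which is why the CLI now forwards distributed_executor_backend into the engine config.

# Illustrative only: the flags referenced in this diff, combined in one config.
from lmdeploy.messages import PytorchEngineConfig

backend_config = PytorchEngineConfig(
    enable_return_routed_experts=True,   # record the experts each token was routed to
    enable_transfer_obj_ref=True,        # hand them back as a serialized Ray ObjectRef
    distributed_executor_backend='ray',  # required for the ObjectRef path in EngineInstance
)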

lmdeploy/pytorch/engine/engine.py

Lines changed: 1 addition & 5 deletions
@@ -927,18 +927,14 @@ def _make_infer_outputs(
             num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1
             spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens)
             req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info)
-            routed_experts = msg.routed_experts if msg.return_routed_experts and finish else None
-            if routed_experts is not None and self.engine_config.enable_transfer_obj_ref:
-                # only serialize for api server
-                routed_experts = self.executor.serialize(routed_experts)
             out = InferOutput(session_id=session_id,
                               resp=msg.resp,
                               finish=finish,
                               token_ids=token_ids,
                               cache_block_ids=cache_block_ids,
                               req_metrics=req_metrics,
                               logprobs=cur_logprobs,
-                              routed_experts=routed_experts)
+                              routed_experts=msg.routed_experts)
             outputs[session_id] = out
 
             if msg.return_logits:

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 26 additions & 6 deletions
@@ -84,11 +84,30 @@ def __init__(self, engine: Engine):
         self.req_sender = engine.req_manager.build_sender()
 
         self.max_input_len = self.engine.max_session_len
+        self._enable_transfer_obj_ref = engine.engine_config.enable_transfer_obj_ref and \
+            engine.engine_config.distributed_executor_backend == 'ray'
 
     def __del__(self):
         """Destructor."""
         self.engine.req_manager.senders.pop(self.req_sender.sender_id)
 
+    def _get_extra_outputs(self, resp: Response):
+        """Get extra outputs."""
+        outputs = dict(routed_experts=None)
+        routed_experts = resp.data.get('routed_experts', None) if resp.data else None
+        if routed_experts is not None and resp.type in [ResponseType.FINISH, ResponseType.CANCEL]:
+            if self._enable_transfer_obj_ref:
+                import base64
+
+                import ray
+
+                ref = ray.put(routed_experts)
+                data = ray.cloudpickle.dumps(ref)
+                outputs['routed_experts'] = base64.b64encode(data).decode('utf-8')
+            else:
+                outputs['routed_experts'] = routed_experts
+        return outputs
+
     async def _async_try_add_session(self, session_id: int):
         """Add new session.
 
@@ -152,27 +171,28 @@ async def async_stream_infer(self,
             cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
             req_metrics = resp.data.get('req_metrics', None) if resp.data else None
             logprobs = resp.data.pop('logprobs', None) if resp.data else None
-            routed_experts = resp.data.get('routed_experts', None) if resp.data else None
+            extra_outputs = self._get_extra_outputs(resp)
+            routed_experts = extra_outputs.get('routed_experts', None)
 
             if resp.type == ResponseType.SUCCESS:
-                token_ids = resp.data['token_ids'].tolist()
+                token_ids = resp.data['token_ids']
                 num_ids = len(token_ids) - output_offset
                 logger.debug(f'session[{session_id}] success: num_out_ids={num_ids}.')
                 yield EngineOutput(resp.type,
-                                   token_ids[output_offset:],
+                                   token_ids[output_offset:].tolist(),
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
                                    routed_experts=routed_experts,
                                    logprobs=logprobs)
                 output_offset = len(token_ids)
-            elif resp.type == ResponseType.FINISH:
+            elif resp.type in (ResponseType.FINISH, ResponseType.CANCEL):
                 resp_data = resp.data
-                token_ids = resp_data['token_ids'].tolist()
+                token_ids = resp_data['token_ids']
                 logits = resp_data['logits']
                 num_ids = len(token_ids) - output_offset
                 logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
                 yield EngineOutput(resp.type,
-                                   token_ids[output_offset:],
+                                   token_ids[output_offset:].tolist(),
                                    logits=logits,
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
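The serialization that previously lived in Engine._make_infer_outputs and RayExecutor.serialize now happens once here, in _get_extra_outputs, so canceled requests take the same path as finished ones. For reference, a rough sketch of the receiving side (decode_routed_experts is a hypothetical helper, not part of this commit): the base64 string can be turned back into a Ray ObjectRef and resolved, assuming the consumer is attached to the same Ray cluster.

# Hypothetical inverse of the encoding done in _get_extra_outputs above.
import base64

import ray


def decode_routed_experts(encoded: str):
    """Recover routed-experts data from the base64-encoded cloudpickle payload."""
    ref = ray.cloudpickle.loads(base64.b64decode(encoded))  # rebuild the ObjectRef
    return ray.get(ref)  # fetch the routed experts from the Ray object store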

lmdeploy/pytorch/engine/executor/base.py

Lines changed: 0 additions & 4 deletions
@@ -102,10 +102,6 @@ def release(self):
         """Release resources."""
         raise NotImplementedError('Not Implemented.')
 
-    def serialize(self, obj):
-        """Serialize obj."""
-        return obj
-
     async def forward_async(self, inputs):
         """Start forward."""
         raise NotImplementedError('Not Implemented')

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 0 additions & 8 deletions
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import asyncio
-import base64
 import contextlib
 import json
 import os
@@ -354,13 +353,6 @@ def wakeup(self, tags: Optional[List[str]] = None):
         self.update_configs()
         self.collective_rpc('wakeup', (tags, ))
 
-    def serialize(self, obj) -> str:
-        """Serialize obj."""
-        ref = ray.put(obj)
-        data = ray.cloudpickle.dumps(ref)
-        data = base64.b64encode(data).decode('utf-8')
-        return data
-
     def get_input_processor(self):
         """Build cache engine."""
         return ray.get(self.workers[0].get_input_processor.remote())
