@@ -84,11 +84,30 @@ def __init__(self, engine: Engine):
8484 self .req_sender = engine .req_manager .build_sender ()
8585
8686 self .max_input_len = self .engine .max_session_len
87+ self ._enable_transfer_obj_ref = engine .engine_config .enable_transfer_obj_ref and \
88+ engine .engine_config .distributed_executor_backend == 'ray'
8789
8890 def __del__ (self ):
8991 """Destructor."""
9092 self .engine .req_manager .senders .pop (self .req_sender .sender_id )
9193
94+ def _get_extra_outputs (self , resp : Response ):
95+ """Get extra outputs."""
96+ outputs = dict (routed_experts = None )
97+ routed_experts = resp .data .get ('routed_experts' , None ) if resp .data else None
98+ if routed_experts is not None and resp .type in [ResponseType .FINISH , ResponseType .CANCEL ]:
99+ if self ._enable_transfer_obj_ref :
100+ import base64
101+
102+ import ray
103+
104+ ref = ray .put (routed_experts )
105+ data = ray .cloudpickle .dumps (ref )
106+ outputs ['routed_experts' ] = base64 .b64encode (data ).decode ('utf-8' )
107+ else :
108+ outputs ['routed_experts' ] = routed_experts
109+ return outputs
110+
92111 async def _async_try_add_session (self , session_id : int ):
93112 """Add new session.
94113
@@ -152,27 +171,28 @@ async def async_stream_infer(self,
152171 cache_block_ids = resp .data .get ('cache_block_ids' , None ) if resp .data else None
153172 req_metrics = resp .data .get ('req_metrics' , None ) if resp .data else None
154173 logprobs = resp .data .pop ('logprobs' , None ) if resp .data else None
155- routed_experts = resp .data .get ('routed_experts' , None ) if resp .data else None
174+ extra_outputs = self ._get_extra_outputs (resp )
175+ routed_experts = extra_outputs .get ('routed_experts' , None )
156176
157177 if resp .type == ResponseType .SUCCESS :
158- token_ids = resp .data ['token_ids' ]. tolist ()
178+ token_ids = resp .data ['token_ids' ]
159179 num_ids = len (token_ids ) - output_offset
160180 logger .debug (f'session[{ session_id } ] success: num_out_ids={ num_ids } .' )
161181 yield EngineOutput (resp .type ,
162- token_ids [output_offset :],
182+ token_ids [output_offset :]. tolist () ,
163183 cache_block_ids = cache_block_ids ,
164184 req_metrics = req_metrics ,
165185 routed_experts = routed_experts ,
166186 logprobs = logprobs )
167187 output_offset = len (token_ids )
168- elif resp .type == ResponseType .FINISH :
188+ elif resp .type in ( ResponseType .FINISH , ResponseType . CANCEL ) :
169189 resp_data = resp .data
170- token_ids = resp_data ['token_ids' ]. tolist ()
190+ token_ids = resp_data ['token_ids' ]
171191 logits = resp_data ['logits' ]
172192 num_ids = len (token_ids ) - output_offset
173193 logger .debug (f'session[{ session_id } ] finish: num_out_ids={ num_ids } .' )
174194 yield EngineOutput (resp .type ,
175- token_ids [output_offset :],
195+ token_ids [output_offset :]. tolist () ,
176196 logits = logits ,
177197 cache_block_ids = cache_block_ids ,
178198 req_metrics = req_metrics ,
0 commit comments