11# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
22# SPDX-License-Identifier: MIT
3+ import time
34from typing import Optional , Callable , Any , overload , Dict , Generic , Iterator , TypeVar , List , cast , AsyncIterator
45from functools import wraps
56
@@ -82,7 +83,7 @@ def sync_wrapper(*args: Any, **kwargs: Any):
8283 output = res
8384 if process_outputs :
8485 output = process_outputs (output )
85-
86+ inject_inner_token ( span , output )
8687 span .set_output (output )
8788 except StopIteration :
8889 pass
@@ -116,6 +117,7 @@ async def async_wrapper(*args: Any, **kwargs: Any):
116117 output = res
117118 if process_outputs :
118119 output = process_outputs (output )
120+ inject_inner_token (span , output )
119121 span .set_output (output )
120122 except StopIteration :
121123 pass
@@ -227,7 +229,7 @@ def sync_stream_wrapper(*args: Any, **kwargs: Any):
227229 res = func (* args , ** kwargs )
228230 output = res
229231 if hasattr (output , "__iter__" ):
230- return _CozeLoopTraceStream (output , span , process_iterator_outputs )
232+ return _CozeLoopTraceStream (output , span , process_iterator_outputs , span_type )
231233 if process_outputs :
232234 output = process_outputs (output )
233235
@@ -262,7 +264,7 @@ async def async_stream_wrapper(*args: Any, **kwargs: Any):
262264 res = await func (* args , ** kwargs )
263265 output = res
264266 if hasattr (output , "__aiter__" ):
265- return _CozeLoopAsyncTraceStream (output , span , process_iterator_outputs )
267+ return _CozeLoopAsyncTraceStream (output , span , process_iterator_outputs , span_type )
266268 if process_outputs :
267269 output = process_outputs (output )
268270 span .set_output (output )
@@ -289,7 +291,6 @@ async def async_stream_wrapper(*args: Any, **kwargs: Any):
289291 if not hasattr (res , "__aiter__" ) and res :
290292 return res
291293
292-
293294 if is_async_gen_func (func ):
294295 return async_gen_wrapper
295296 if is_gen_func (func ):
@@ -317,11 +318,14 @@ def __init__(
317318 stream : Iterator [S ],
318319 span : Span ,
319320 process_iterator_outputs : Optional [Callable [[Any ], Any ]] = None ,
321+ span_type : str = "" ,
320322 ):
321323 self .__stream__ = stream
322324 self .__span = span
323325 self .__output__ : list [S ] = []
324326 self .__process_iterator_outputs = process_iterator_outputs
327+ self .__is_set_start_time_first_token : bool = False
328+ self .__span_type = span_type
325329
326330 def __next__ (self ) -> S :
327331 try :
@@ -360,13 +364,17 @@ def __streamer__(
360364 while True :
361365 s = next (temp_stream )
362366 self .__output__ .append (s )
367+ if not self .__is_set_start_time_first_token and self .__span_type == "model" :
368+ self .__span .set_start_time_first_resp (time .time_ns () // 1_000 )
369+ self .__is_set_start_time_first_token = True
363370 yield s
364371 except StopIteration as e :
365372 return e
366373
367374 def __end__ (self , err : Exception = None ):
368375 if self .__process_iterator_outputs :
369376 self .__output__ = self .__process_iterator_outputs (self .__output__ )
377+ inject_inner_token (self .__span , self .__output__ )
370378 self .__span .set_output (self .__output__ )
371379 if err :
372380 self .__span .set_error (err )
@@ -379,17 +387,21 @@ def __init__(
379387 stream : AsyncIterator [S ],
380388 span : Span ,
381389 process_iterator_outputs : Optional [Callable [[Any ], Any ]] = None ,
390+ span_type : str = "" ,
382391 ):
383392 self .__stream__ = stream
384393 self .__span = span
385394 self .__output__ : list [S ] = []
386395 self .__process_iterator_outputs = process_iterator_outputs
396+ self .__is_set_start_time_first_token : bool = False
397+ self .__span_type = span_type
387398
388399 async def _aend (self , error : Optional [Exception ] = None ):
389400 if error :
390401 self .__span .set_error (error )
391402 if self .__process_iterator_outputs :
392403 self .__output__ = self .__process_iterator_outputs (self .__output__ )
404+ inject_inner_token (self .__span , self .__output__ )
393405 self .__span .set_output (self .__output__ )
394406 self .__span .finish ()
395407
@@ -433,6 +445,17 @@ async def __async_streamer__(
433445 while True :
434446 s = await temp_stream .__anext__ ()
435447 self .__output__ .append (s )
448+ if not self .__is_set_start_time_first_token and self .__span_type == "model" :
449+ self .__span .set_start_time_first_resp (time .time_ns () // 1_000 )
450+ self .__is_set_start_time_first_token = True
436451 yield s
437452 except StopIteration :
438453 pass
454+
455+
def inject_inner_token(span: "Span", src: Any) -> None:
    """Copy token counts embedded in an output payload onto ``span``.

    The payload is inspected only when it is a ``dict`` that holds a
    non-empty ``dict`` under the ``"_inner_tokens_dict"`` key. From that
    nested dict, a truthy ``"input_tokens"`` value is forwarded to
    ``span.set_input_tokens`` and a truthy ``"output_tokens"`` value to
    ``span.set_output_tokens``. Missing or zero counts are skipped, and any
    other payload shape is ignored.

    Args:
        span: Span that receives the token counts.
        src: Output value produced by the traced callable; may be anything.
    """
    if not isinstance(src, dict):
        return

    # Look the nested dict up once and validate its type: a truthy non-dict
    # value (string, list, number, ...) has no usable ``.get`` and would
    # otherwise raise AttributeError from inside the tracing wrapper.
    tokens = src.get("_inner_tokens_dict")
    if not isinstance(tokens, dict):
        return

    if input_tokens := tokens.get("input_tokens", 0):
        span.set_input_tokens(input_tokens)
    if output_tokens := tokens.get("output_tokens", 0):
        span.set_output_tokens(output_tokens)
0 commit comments