55import json
66import time
77import traceback
8- from typing import List , Dict , Union , Any , Optional
8+ from typing import List , Dict , Union , Any , Optional , Callable , Protocol
99
1010import pydantic
1111from pydantic import Field , BaseModel
1212from langchain_core .callbacks .base import BaseCallbackHandler
13- from langchain_core .outputs import LLMResult
13+ from langchain_core .outputs import LLMResult , ChatGeneration
1414from langchain_core .agents import AgentFinish , AgentAction
1515from langchain_core .prompt_values import PromptValue , ChatPromptValue
16- from langchain_core .messages import BaseMessage , AIMessageChunk
16+ from langchain_core .messages import BaseMessage , AIMessageChunk , AIMessage
1717from langchain_core .prompts import AIMessagePromptTemplate , HumanMessagePromptTemplate , SystemMessagePromptTemplate
1818from langchain_core .outputs import ChatGenerationChunk , GenerationChunk
1919
2929
class LoopTracer:
    """Factory for per-request trace callback handlers."""

    @classmethod
    def get_callback_handler(
        cls,
        client: Optional[Client] = None,
        modify_name_fn: Optional[Callable[[str], str]] = None,
        add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
        tags: Optional[Dict[str, Any]] = None,
    ):
        """Install *client* (or the default client) globally and return a new handler.

        Do not hold the handler for a long time; get a new callback handler
        for each request.

        Args:
            client: trace client spans are reported to; when falsy, the client
                returned by get_default_client() is installed instead.
            modify_name_fn: maps a node name (e.g. the name passed to
                langgraph's add_node(node_name, node_func)) to the span name.
            add_tags_fn: maps a node name to a dict of extra tags for the span.
            tags: tags applied to every span just before it finishes.

        Returns:
            A fresh LoopTraceCallbackHandler bound to the installed client.
        """
        global _trace_callback_client
        if client:
            _trace_callback_client = client
        else:
            _trace_callback_client = get_default_client()

        return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
4351
4452
class LoopTraceCallbackHandler(BaseCallbackHandler):
    """Langchain callback handler that records runs as trace spans."""

    def __init__(
        self,
        name_fn: Optional[Callable[[str], str]] = None,
        tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
        tags: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            name_fn: optional mapper from node name to span name.
            tags_fn: optional mapper from node name to extra span tags.
            tags: tags set on every span right before it finishes.
        """
        super().__init__()
        # workspace of the module-level client installed by
        # LoopTracer.get_callback_handler
        self._space_id = _trace_callback_client.workspace_id
        # run_id -> Run bookkeeping; lets child runs find their parent span
        self.run_map: Dict[str, Run] = {}
        self.name_fn = name_fn
        self.tags_fn = tags_fn
        self._tags = tags if tags else {}
5066
5167 def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
5268 span_tags = {}
@@ -98,32 +114,26 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
98114 try :
99115 # set output span_tag
100116 flow_span .set_tags ({'output' : ModelTraceOutput (response .generations ).to_json ()})
117+ # set model tags
118+ tags = self ._get_model_tags (response , ** kwargs )
119+ if tags :
120+ self ._set_span_tags (flow_span , tags , need_convert_tag_value = False )
101121 except Exception as e :
102122 flow_span .set_error (e )
103- # calculate token usage,and set span_tag
104- if response .llm_output is not None and 'token_usage' in response .llm_output and response .llm_output ['token_usage' ]:
105- self ._set_span_tags (flow_span , response .llm_output ['token_usage' ], need_convert_tag_value = False )
106- else :
107- try :
108- run_info = self .run_map [str (kwargs ['run_id' ])]
109- if run_info is not None and run_info .model_meta is not None :
110- model_name = run_info .model_meta .model_name
111- input_messages = run_info .model_meta .message
112- flow_span .set_input_tokens (calc_token_usage (input_messages , model_name ))
113- flow_span .set_output_tokens (calc_token_usage (response , model_name ))
114- except Exception as e :
115- flow_span .set_error (e )
116123 # finish flow_span
117- flow_span . finish ( )
124+ self . _end_flow_span ( flow_span )
118125
119126 def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
120127 flow_span = None
121128 try :
122129 if kwargs .get ('run_type' , '' ) == 'prompt' or kwargs .get ('name' , '' ) == 'ChatPromptTemplate' :
123- flow_span = self ._new_flow_span (kwargs ['name' ], kwargs [ 'name' ] , ** kwargs )
130+ flow_span = self ._new_flow_span (kwargs ['name' ], 'prompt' , ** kwargs )
124131 self ._on_prompt_start (flow_span , serialized , inputs , ** kwargs )
125132 else :
126- flow_span = self ._new_flow_span (kwargs ['name' ], kwargs ['name' ], ** kwargs )
133+ span_type = 'chain'
134+ if kwargs ['name' ] == 'LangGraph' : # LangGraph is agent span_type,for trajectory evaluation aggregate to an agent
135+ span_type = 'agent'
136+ flow_span = self ._new_flow_span (kwargs ['name' ], span_type , ** kwargs )
127137 flow_span .set_tags ({'input' : _convert_2_json (inputs )})
128138 except Exception as e :
129139 if flow_span is not None :
@@ -142,7 +152,7 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
142152 flow_span .set_tags ({'output' : _convert_2_json (outputs )})
143153 except Exception as e :
144154 flow_span .set_error (e )
145- flow_span . finish ( )
155+ self . _end_flow_span ( flow_span )
146156
147157 def on_chain_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
148158 flow_span = self ._get_flow_span (** kwargs )
@@ -151,7 +161,7 @@ def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: A
151161 flow_span = self ._new_flow_span (span_name , 'chain_error' , ** kwargs )
152162 flow_span .set_error (error )
153163 flow_span .set_tags ({'error_trace' : traceback .format_exc ()})
154- flow_span . finish ( )
164+ self . _end_flow_span ( flow_span )
155165
156166 def on_tool_start (
157167 self , serialized : Dict [str , Any ], input_str : str , ** kwargs : Any
@@ -167,7 +177,7 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
167177 flow_span .set_tags ({'output' : _convert_2_json (output )})
168178 except Exception as e :
169179 flow_span .set_error (e )
170- flow_span . finish ( )
180+ self . _end_flow_span ( flow_span )
171181
172182 def on_tool_error (
173183 self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
@@ -178,7 +188,7 @@ def on_tool_error(
178188 flow_span = self ._new_flow_span (span_name , 'tool_error' , ** kwargs )
179189 flow_span .set_error (error )
180190 flow_span .set_tags ({'error_trace' : traceback .format_exc ()})
181- flow_span . finish ( )
191+ self . _end_flow_span ( flow_span )
182192
183193 def on_text (self , text : str , ** kwargs : Any ) -> Any :
184194 """Run on arbitrary text."""
@@ -189,6 +199,67 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
189199 def on_agent_finish (self , finish : AgentFinish , ** kwargs : Any ) -> Any :
190200 return
191201
202+ def _end_flow_span (self , span : Span ):
203+ span .set_tags (self ._tags )
204+ span .finish ()
205+
206+ def _get_model_tags (self , response : LLMResult , ** kwargs : Any ) -> Dict [str , Any ]:
207+ return self ._get_model_token_tags (response , ** kwargs )
208+
209+ def _get_model_token_tags (self , response : LLMResult , ** kwargs : Any ) -> Dict [str , Any ]:
210+ result = {}
211+ is_get_from_langchain = False
212+ if response .llm_output is not None and 'token_usage' in response .llm_output and response .llm_output [
213+ 'token_usage' ]:
214+ is_get_from_langchain = True
215+ result ['input_tokens' ] = response .llm_output .get ('token_usage' , {}).get ('prompt_tokens' , 0 )
216+ result ['output_tokens' ] = response .llm_output .get ('token_usage' , {}).get ('completion_tokens' , 0 )
217+ result ['tokens' ] = result ['input_tokens' ] + result ['output_tokens' ]
218+ reasoning_tokens = response .llm_output .get ('token_usage' , {}).get ('completion_tokens_details' , {}).get (
219+ 'reasoning_tokens' , 0 )
220+ if reasoning_tokens :
221+ result ['reasoning_tokens' ] = reasoning_tokens
222+ input_cached_tokens = response .llm_output .get ('token_usage' , {}).get ('prompt_tokens_details' , {}).get (
223+ 'cached_tokens' , 0 )
224+ if input_cached_tokens :
225+ result ['input_cached_tokens' ] = input_cached_tokens
226+ elif response .generations is not None and len (response .generations ) > 0 and response .generations [0 ] is not None :
227+ for i , generation in enumerate (response .generations [0 ]):
228+ if isinstance (generation , ChatGeneration ) and isinstance (generation .message ,(AIMessageChunk , AIMessage )) and generation .message .usage_metadata :
229+ is_get_from_langchain = True
230+ result ['input_tokens' ] = generation .message .usage_metadata .get ('input_tokens' , 0 )
231+ result ['output_tokens' ] = generation .message .usage_metadata .get ('output_tokens' , 0 )
232+ result ['tokens' ] = result ['input_tokens' ] + result ['output_tokens' ]
233+ if generation .message .usage_metadata .get ('output_token_details' , {}):
234+ reasoning_tokens = generation .message .usage_metadata .get ('output_token_details' , {}).get ('reasoning' , 0 )
235+ if reasoning_tokens :
236+ result ['reasoning_tokens' ] = reasoning_tokens
237+ if generation .message .usage_metadata .get ('input_token_details' , {}):
238+ input_read_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_read' , 0 )
239+ if input_read_cached_tokens :
240+ result ['input_cached_tokens' ] = input_read_cached_tokens
241+ input_creation_cached_tokens = generation .message .usage_metadata .get ('input_token_details' , {}).get ('cache_creation' , 0 )
242+ if input_creation_cached_tokens :
243+ result ['input_creation_cached_tokens' ] = input_creation_cached_tokens
244+ if is_get_from_langchain :
245+ return result
246+ else :
247+ try :
248+ run_info = self .run_map [str (kwargs ['run_id' ])]
249+ if run_info is not None and run_info .model_meta is not None :
250+ model_name = run_info .model_meta .model_name
251+ input_messages = run_info .model_meta .message
252+ token_usage = {
253+ 'input_tokens' : calc_token_usage (input_messages , model_name ),
254+ 'output_tokens' : calc_token_usage (response , model_name ),
255+ 'tokens' : 0
256+ }
257+ token_usage ['tokens' ] = token_usage ['input_tokens' ] + token_usage ['output_tokens' ]
258+ return token_usage
259+ except Exception as e :
260+ span_tags = {'error_info' : repr (e ), 'error_trace' : traceback .format_exc ()}
261+ return span_tags
262+
192263 def _on_prompt_start (self , flow_span , serialized : Dict [str , Any ], inputs : (Dict [str , Any ], str ), ** kwargs : Any ) -> None :
193264 # get inputs
194265 params : List [Argument ] = []
@@ -234,18 +305,38 @@ def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[
234305 flow_span .set_tags ({'prompt_version' : kwargs ['metadata' ]['lc_hub_commit_hash' ]})
235306 flow_span .set_tags ({'prompt_provider' : 'langsmith' })
236307
237- def _new_flow_span (self , span_name : str , span_type : str , ** kwargs : Any ) -> Span :
308+ def _new_flow_span (self , node_name : str , span_type : str , ** kwargs : Any ) -> Span :
238309 span_type = _span_type_mapping (span_type )
310+ span_name = node_name
239311 # set parent span
240312 parent_span : Span = None
241313 if 'parent_run_id' in kwargs and kwargs ['parent_run_id' ] is not None and str (kwargs ['parent_run_id' ]) in self .run_map :
242314 parent_span = self .run_map [str (kwargs ['parent_run_id' ])].span
315+ # modify name
316+ error_tag = {}
317+ try :
318+ if self .name_fn :
319+ name = self .name_fn (node_name )
320+ if name :
321+ span_name = name
322+ except Exception as e :
323+ error_tag = {'error_info' : f'name_fn error { repr (e )} ' , 'error_trace' : traceback .format_exc ()}
243324 # new span
244325 flow_span = _trace_callback_client .start_span (span_name , span_type , child_of = parent_span )
245326 run_id = str (kwargs ['run_id' ])
246327 self .run_map [run_id ] = Run (run_id , flow_span , span_type )
247328 # set default tags
248329 flow_span .set_runtime (RuntimeInfo ())
330+ # set extra tags
331+ try :
332+ if self .tags_fn :
333+ tags = self .tags_fn (node_name )
334+ if isinstance (tags , dict ):
335+ flow_span .set_tags (tags )
336+ except Exception as e :
337+ error_tag = {'error_info' : f'tags_fn error { repr (e )} ' , 'error_trace' : traceback .format_exc ()}
338+ if error_tag :
339+ flow_span .set_tags (error_tag )
249340 return flow_span
250341
251342 def _get_flow_span (self , ** kwargs : Any ) -> Span :
@@ -417,7 +508,7 @@ def _convert_inputs(inputs: Any) -> Any:
417508 for each in inputs :
418509 format_inputs .append (_convert_inputs (each ))
419510 return format_inputs
420- if isinstance (inputs , AIMessageChunk ):
511+ if isinstance (inputs , ( AIMessageChunk , AIMessage ) ):
421512 """
422513 Must be before BaseMessage.
423514 """
# (removed stray GitHub diff-page footer: "0 commit comments")