Commit 9455350

support langchain callback set tags and name

1 parent 64d0381 commit 9455350

6 files changed: +208 additions, -47 deletions

CHANGLOG.md

Lines changed: 5 additions & 0 deletions

@@ -1,3 +1,8 @@
+## [0.1.19] - 2025-11-10
+### Added
+- langchain callback support langchain V1
+- langchain callback support set tag and name
+
 ## [0.1.19] - 2025-11-10
 ### Fixed
 - fix baggage escape problem

cozeloop/integration/langchain/trace_callback.py

Lines changed: 119 additions & 28 deletions

@@ -5,15 +5,15 @@
 import json
 import time
 import traceback
-from typing import List, Dict, Union, Any, Optional
+from typing import List, Dict, Union, Any, Optional, Callable, Protocol
 
 import pydantic
 from pydantic import Field, BaseModel
 from langchain_core.callbacks.base import BaseCallbackHandler
-from langchain_core.outputs import LLMResult
+from langchain_core.outputs import LLMResult, ChatGeneration
 from langchain_core.agents import AgentFinish, AgentAction
 from langchain_core.prompt_values import PromptValue, ChatPromptValue
-from langchain_core.messages import BaseMessage, AIMessageChunk
+from langchain_core.messages import BaseMessage, AIMessageChunk, AIMessage
 from langchain_core.prompts import AIMessagePromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

@@ -29,24 +29,40 @@
 
 class LoopTracer:
     @classmethod
-    def get_callback_handler(cls, client: Client = None):
+    def get_callback_handler(
+            cls,
+            client: Client = None,
+            modify_name_fn: Optional[Callable[[str], str]] = None,
+            add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
+            tags: Dict[str, Any] = None,
+    ):
         """
         Do not hold it for a long time, get a new callback_handler for each request.
+        modify_name_fn: maps a node name to a span name (with langgraph, e.g. add_node(node_name, node_func), the input is that node name).
+        add_tags_fn: maps a node name to a tags dict for that node's span (same node-name input as modify_name_fn).
         """
         global _trace_callback_client
         if client:
             _trace_callback_client = client
         else:
             _trace_callback_client = get_default_client()
 
-        return LoopTraceCallbackHandler()
+        return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
 
 
 class LoopTraceCallbackHandler(BaseCallbackHandler):
-    def __init__(self):
+    def __init__(
+            self,
+            name_fn: Optional[Callable[[str], str]] = None,
+            tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
+            tags: Dict[str, Any] = None,
+    ):
         super().__init__()
         self._space_id = _trace_callback_client.workspace_id
         self.run_map: Dict[str, Run] = {}
+        self.name_fn = name_fn
+        self.tags_fn = tags_fn
+        self._tags = tags if tags else {}
 
     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
         span_tags = {}
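
For context, a minimal usage sketch of the new parameters (not part of the diff; the graph object and tag values are illustrative assumptions, only the handler API above comes from this commit):

    from typing import Any, Dict

    from cozeloop.integration.langchain.trace_callback import LoopTracer

    def rename_node(node_name: str) -> str:
        # e.g. prefix each LangGraph node name before it becomes the span name
        return f"my_graph.{node_name}"

    def tags_for_node(node_name: str) -> Dict[str, Any]:
        # per-node tags, attached when each span is created
        return {"node_name": node_name}

    handler = LoopTracer.get_callback_handler(
        modify_name_fn=rename_node,
        add_tags_fn=tags_for_node,
        tags={"env": "staging"},  # set on every span when it finishes
    )
    # Pass it per request, as the docstring advises (graph is hypothetical):
    # result = graph.invoke(inputs, config={"callbacks": [handler]})
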
@@ -98,32 +114,26 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
         try:
             # set output span_tag
             flow_span.set_tags({'output': ModelTraceOutput(response.generations).to_json()})
+            # set model tags
+            tags = self._get_model_tags(response, **kwargs)
+            if tags:
+                self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
         except Exception as e:
             flow_span.set_error(e)
-        # calculate token usage, and set span_tag
-        if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
-            self._set_span_tags(flow_span, response.llm_output['token_usage'], need_convert_tag_value=False)
-        else:
-            try:
-                run_info = self.run_map[str(kwargs['run_id'])]
-                if run_info is not None and run_info.model_meta is not None:
-                    model_name = run_info.model_meta.model_name
-                    input_messages = run_info.model_meta.message
-                    flow_span.set_input_tokens(calc_token_usage(input_messages, model_name))
-                    flow_span.set_output_tokens(calc_token_usage(response, model_name))
-            except Exception as e:
-                flow_span.set_error(e)
         # finish flow_span
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
         flow_span = None
         try:
             if kwargs.get('run_type', '') == 'prompt' or kwargs.get('name', '') == 'ChatPromptTemplate':
-                flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
+                flow_span = self._new_flow_span(kwargs['name'], 'prompt', **kwargs)
                 self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
             else:
-                flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
+                span_type = 'chain'
+                if kwargs['name'] == 'LangGraph':  # LangGraph gets the agent span_type, so trajectory evaluation can aggregate the run into an agent
+                    span_type = 'agent'
+                flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
                 flow_span.set_tags({'input': _convert_2_json(inputs)})
         except Exception as e:
             if flow_span is not None:
@@ -142,7 +152,7 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
             flow_span.set_tags({'output': _convert_2_json(outputs)})
         except Exception as e:
             flow_span.set_error(e)
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
         flow_span = self._get_flow_span(**kwargs)
@@ -151,7 +161,7 @@ def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
             flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
         flow_span.set_error(error)
         flow_span.set_tags({'error_trace': traceback.format_exc()})
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_tool_start(
         self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
@@ -167,7 +177,7 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
             flow_span.set_tags({'output': _convert_2_json(output)})
         except Exception as e:
             flow_span.set_error(e)
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -178,7 +188,7 @@ def on_tool_error(
             flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
         flow_span.set_error(error)
         flow_span.set_tags({'error_trace': traceback.format_exc()})
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         """Run on arbitrary text."""
@@ -189,6 +199,67 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         return
 
+    def _end_flow_span(self, span: Span):
+        span.set_tags(self._tags)
+        span.finish()
+
+    def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
+        return self._get_model_token_tags(response, **kwargs)
+
+    def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
+        result = {}
+        is_get_from_langchain = False
+        if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output[
+            'token_usage']:
+            is_get_from_langchain = True
+            result['input_tokens'] = response.llm_output.get('token_usage', {}).get('prompt_tokens', 0)
+            result['output_tokens'] = response.llm_output.get('token_usage', {}).get('completion_tokens', 0)
+            result['tokens'] = result['input_tokens'] + result['output_tokens']
+            reasoning_tokens = response.llm_output.get('token_usage', {}).get('completion_tokens_details', {}).get(
+                'reasoning_tokens', 0)
+            if reasoning_tokens:
+                result['reasoning_tokens'] = reasoning_tokens
+            input_cached_tokens = response.llm_output.get('token_usage', {}).get('prompt_tokens_details', {}).get(
+                'cached_tokens', 0)
+            if input_cached_tokens:
+                result['input_cached_tokens'] = input_cached_tokens
+        elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
+            for i, generation in enumerate(response.generations[0]):
+                if isinstance(generation, ChatGeneration) and isinstance(generation.message, (AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
+                    is_get_from_langchain = True
+                    result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
+                    result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
+                    result['tokens'] = result['input_tokens'] + result['output_tokens']
+                    if generation.message.usage_metadata.get('output_token_details', {}):
+                        reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
+                        if reasoning_tokens:
+                            result['reasoning_tokens'] = reasoning_tokens
+                    if generation.message.usage_metadata.get('input_token_details', {}):
+                        input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
+                        if input_read_cached_tokens:
+                            result['input_cached_tokens'] = input_read_cached_tokens
+                        input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
+                        if input_creation_cached_tokens:
+                            result['input_creation_cached_tokens'] = input_creation_cached_tokens
+        if is_get_from_langchain:
+            return result
+        else:
+            try:
+                run_info = self.run_map[str(kwargs['run_id'])]
+                if run_info is not None and run_info.model_meta is not None:
+                    model_name = run_info.model_meta.model_name
+                    input_messages = run_info.model_meta.message
+                    token_usage = {
+                        'input_tokens': calc_token_usage(input_messages, model_name),
+                        'output_tokens': calc_token_usage(response, model_name),
+                        'tokens': 0
+                    }
+                    token_usage['tokens'] = token_usage['input_tokens'] + token_usage['output_tokens']
+                    return token_usage
+            except Exception as e:
+                span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
+                return span_tags
+
     def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
         # get inputs
         params: List[Argument] = []
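
As a reference for the fallback order above, a hand-built sketch of the two usage shapes _get_model_token_tags reads (the numbers are invented; field names follow OpenAI-style token_usage and langchain_core's usage_metadata):

    from langchain_core.messages import AIMessage
    from langchain_core.outputs import ChatGeneration, LLMResult

    # 1) provider usage surfaced via llm_output['token_usage'] (checked first):
    openai_style = LLMResult(
        generations=[[ChatGeneration(message=AIMessage(content="hi"))]],
        llm_output={"token_usage": {"prompt_tokens": 12, "completion_tokens": 3}},
    )
    # -> {'input_tokens': 12, 'output_tokens': 3, 'tokens': 15}

    # 2) usage_metadata on the AIMessage (the LangChain V1 path this commit adds):
    v1_style = LLMResult(generations=[[ChatGeneration(
        message=AIMessage(
            content="hi",
            usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
        )
    )]])
    # -> same tags; reasoning and cached-token details are added when present.
    # If neither shape is present, the handler falls back to calc_token_usage.
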
@@ -234,18 +305,38 @@ def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
         flow_span.set_tags({'prompt_version': kwargs['metadata']['lc_hub_commit_hash']})
         flow_span.set_tags({'prompt_provider': 'langsmith'})
 
-    def _new_flow_span(self, span_name: str, span_type: str, **kwargs: Any) -> Span:
+    def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
         span_type = _span_type_mapping(span_type)
+        span_name = node_name
         # set parent span
         parent_span: Span = None
         if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
             parent_span = self.run_map[str(kwargs['parent_run_id'])].span
+        # modify name
+        error_tag = {}
+        try:
+            if self.name_fn:
+                name = self.name_fn(node_name)
+                if name:
+                    span_name = name
+        except Exception as e:
+            error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
         # new span
         flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
         run_id = str(kwargs['run_id'])
         self.run_map[run_id] = Run(run_id, flow_span, span_type)
         # set default tags
         flow_span.set_runtime(RuntimeInfo())
+        # set extra tags
+        try:
+            if self.tags_fn:
+                tags = self.tags_fn(node_name)
+                if isinstance(tags, dict):
+                    flow_span.set_tags(tags)
+        except Exception as e:
+            error_tag = {'error_info': f'tags_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
+        if error_tag:
+            flow_span.set_tags(error_tag)
         return flow_span
 
     def _get_flow_span(self, **kwargs: Any) -> Span:
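
One design note on the hunk above: failures in user-supplied callbacks are isolated, so a broken name_fn or tags_fn cannot prevent span creation. A small illustrative sketch (the callback is hypothetical):

    def bad_name_fn(node_name: str) -> str:
        raise ValueError("boom")

    handler = LoopTracer.get_callback_handler(modify_name_fn=bad_name_fn)
    # Spans keep the original node name, and each span instead carries
    # {'error_info': "name_fn error ValueError('boom')", 'error_trace': ...}
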
@@ -417,7 +508,7 @@ def _convert_inputs(inputs: Any) -> Any:
         for each in inputs:
            format_inputs.append(_convert_inputs(each))
        return format_inputs
-    if isinstance(inputs, AIMessageChunk):
+    if isinstance(inputs, (AIMessageChunk, AIMessage)):
        """
        Must be before BaseMessage.
        """
