
Commit e104a2e

Feat langchain v1 (#38)
* support langchain v1
1 parent: 35e5f47

File tree: 7 files changed, +201 −48 lines

CHANGLOG.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -1,3 +1,8 @@
+## [0.1.20] - 2025-11-10
+### Added
+- langchain callback supports LangChain v1
+- langchain callback supports setting tags and names
+
 ## [0.1.19] - 2025-11-10
 ### Fixed
 - fix baggage escape problem
```

cozeloop/integration/langchain/trace_callback.py

Lines changed: 119 additions & 27 deletions
```diff
@@ -5,12 +5,13 @@
 import json
 import time
 import traceback
-from typing import List, Dict, Union, Any, Optional
+from typing import List, Dict, Union, Any, Optional, Callable, Protocol
 
 import pydantic
 from pydantic import Field, BaseModel
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import AgentFinish, AgentAction, LLMResult
+from langchain_core.callbacks.base import BaseCallbackHandler
+from langchain_core.outputs import LLMResult, ChatGeneration
+from langchain_core.agents import AgentFinish, AgentAction
 from langchain_core.prompt_values import PromptValue, ChatPromptValue
 from langchain_core.messages import BaseMessage, AIMessageChunk, AIMessage
 from langchain_core.prompts import AIMessagePromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
@@ -28,24 +29,40 @@
 
 class LoopTracer:
     @classmethod
-    def get_callback_handler(cls, client: Client = None):
+    def get_callback_handler(
+            cls,
+            client: Client = None,
+            modify_name_fn: Optional[Callable[[str], str]] = None,
+            add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
+            tags: Dict[str, Any] = None,
+    ):
         """
         Do not hold it for a long time, get a new callback_handler for each request.
+        modify_name_fn: maps a node name to a span name (with langgraph, the node name is the one passed to add_node(node_name, node_func)).
+        add_tags_fn: maps a node name to a dict of tags to set on that node's span.
         """
         global _trace_callback_client
         if client:
             _trace_callback_client = client
         else:
             _trace_callback_client = get_default_client()
 
-        return LoopTraceCallbackHandler()
+        return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
 
 
 class LoopTraceCallbackHandler(BaseCallbackHandler):
-    def __init__(self):
+    def __init__(
+            self,
+            name_fn: Optional[Callable[[str], str]] = None,
+            tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
+            tags: Dict[str, Any] = None,
+    ):
         super().__init__()
         self._space_id = _trace_callback_client.workspace_id
         self.run_map: Dict[str, Run] = {}
+        self.name_fn = name_fn
+        self.tags_fn = tags_fn
+        self._tags = tags if tags else {}
 
     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
         span_tags = {}
```
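For context, a minimal usage sketch of the new parameters. The import path follows this file's location in the repo (the published package may re-export it elsewhere); `modify_name`, `add_tags`, and the tag values are hypothetical:

```python
from cozeloop.integration.langchain.trace_callback import LoopTracer

# Hypothetical mappers from a node name (e.g. a langgraph add_node name)
# to a span name and extra span tags.
def modify_name(node_name: str) -> str:
    return f"my_app.{node_name}"

def add_tags(node_name: str) -> dict:
    return {"node": node_name}

# Per the docstring, fetch a fresh handler for each request.
handler = LoopTracer.get_callback_handler(
    modify_name_fn=modify_name,
    add_tags_fn=add_tags,
    tags={"env": "staging"},  # global tags applied to every span
)
# e.g. chain.invoke(inputs, config={"callbacks": [handler]})
```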
```diff
@@ -97,32 +114,26 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
         try:
             # set output span_tag
             flow_span.set_tags({'output': ModelTraceOutput(response.generations).to_json()})
+            # set model tags
+            tags = self._get_model_tags(response, **kwargs)
+            if tags:
+                self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
         except Exception as e:
             flow_span.set_error(e)
-        # calculate token usage,and set span_tag
-        if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
-            self._set_span_tags(flow_span, response.llm_output['token_usage'], need_convert_tag_value=False)
-        else:
-            try:
-                run_info = self.run_map[str(kwargs['run_id'])]
-                if run_info is not None and run_info.model_meta is not None:
-                    model_name = run_info.model_meta.model_name
-                    input_messages = run_info.model_meta.message
-                    flow_span.set_input_tokens(calc_token_usage(input_messages, model_name))
-                    flow_span.set_output_tokens(calc_token_usage(response, model_name))
-            except Exception as e:
-                flow_span.set_error(e)
         # finish flow_span
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
         flow_span = None
         try:
             if kwargs.get('run_type', '') == 'prompt' or kwargs.get('name', '') == 'ChatPromptTemplate':
-                flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
+                flow_span = self._new_flow_span(kwargs['name'], 'prompt', **kwargs)
                 self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
             else:
-                flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
+                span_type = 'chain'
+                if kwargs['name'] == 'LangGraph':  # the LangGraph root run gets the agent span_type, so trajectory evaluation can aggregate it into one agent
+                    span_type = 'agent'
+                flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
                 flow_span.set_tags({'input': _convert_2_json(inputs)})
         except Exception as e:
             if flow_span is not None:
```
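Since a compiled langgraph graph emits its root run under the name 'LangGraph', the branch above types that run as an agent span, with node runs as children. A sketch under that assumption, reusing `handler` from the example above (the state shape and node are illustrative):

```python
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END

class State(TypedDict):
    q: str

def draft(state: State) -> dict:
    return {"q": state["q"] + "!"}

builder = StateGraph(State)
builder.add_node("draft", draft)   # "draft" is the node name passed to name_fn/tags_fn
builder.add_edge(START, "draft")
builder.add_edge("draft", END)
graph = builder.compile()

# The root run is named 'LangGraph', so the handler opens an 'agent' span;
# the 'draft' node run becomes a child span.
graph.invoke({"q": "hi"}, config={"callbacks": [handler]})
```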
```diff
@@ -141,7 +152,7 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> An
             flow_span.set_tags({'output': _convert_2_json(outputs)})
         except Exception as e:
             flow_span.set_error(e)
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
         flow_span = self._get_flow_span(**kwargs)
@@ -150,7 +161,7 @@ def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: A
             flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
         flow_span.set_error(error)
         flow_span.set_tags({'error_trace': traceback.format_exc()})
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_tool_start(
         self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
@@ -166,7 +177,7 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
             flow_span.set_tags({'output': _convert_2_json(output)})
         except Exception as e:
             flow_span.set_error(e)
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_tool_error(
         self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -177,7 +188,7 @@ def on_tool_error(
             flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
         flow_span.set_error(error)
         flow_span.set_tags({'error_trace': traceback.format_exc()})
-        flow_span.finish()
+        self._end_flow_span(flow_span)
 
     def on_text(self, text: str, **kwargs: Any) -> Any:
         """Run on arbitrary text."""
@@ -188,6 +199,66 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
     def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
         return
 
+    def _end_flow_span(self, span: Span):
+        span.finish()
+
+    def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
+        return self._get_model_token_tags(response, **kwargs)
+
+    def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
+        result = {}
+        is_get_from_langchain = False
+        if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
+            is_get_from_langchain = True
+            result['input_tokens'] = response.llm_output.get('token_usage', {}).get('prompt_tokens', 0)
+            result['output_tokens'] = response.llm_output.get('token_usage', {}).get('completion_tokens', 0)
+            result['tokens'] = result['input_tokens'] + result['output_tokens']
+            reasoning_tokens = response.llm_output.get('token_usage', {}).get('completion_tokens_details', {}).get('reasoning_tokens', 0)
+            if reasoning_tokens:
+                result['reasoning_tokens'] = reasoning_tokens
+            input_cached_tokens = response.llm_output.get('token_usage', {}).get('prompt_tokens_details', {}).get('cached_tokens', 0)
+            if input_cached_tokens:
+                result['input_cached_tokens'] = input_cached_tokens
+        elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
+            for i, generation in enumerate(response.generations[0]):
+                if isinstance(generation, ChatGeneration) and isinstance(generation.message, (AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
+                    is_get_from_langchain = True
+                    result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
+                    result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
+                    result['tokens'] = result['input_tokens'] + result['output_tokens']
+                    if generation.message.usage_metadata.get('output_token_details', {}):
+                        reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
+                        if reasoning_tokens:
+                            result['reasoning_tokens'] = reasoning_tokens
+                    if generation.message.usage_metadata.get('input_token_details', {}):
+                        input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
+                        if input_read_cached_tokens:
+                            result['input_cached_tokens'] = input_read_cached_tokens
+                        input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
+                        if input_creation_cached_tokens:
+                            result['input_creation_cached_tokens'] = input_creation_cached_tokens
+        if is_get_from_langchain:
+            return result
+        else:
+            try:
+                run_info = self.run_map[str(kwargs['run_id'])]
+                if run_info is not None and run_info.model_meta is not None:
+                    model_name = run_info.model_meta.model_name
+                    input_messages = run_info.model_meta.message
+                    token_usage = {
+                        'input_tokens': calc_token_usage(input_messages, model_name),
+                        'output_tokens': calc_token_usage(response, model_name),
+                        'tokens': 0,
+                    }
+                    token_usage['tokens'] = token_usage['input_tokens'] + token_usage['output_tokens']
+                    return token_usage
+            except Exception as e:
+                span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
+                return span_tags
+
     def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[str, Any], str), **kwargs: Any) -> None:
         # get inputs
         params: List[Argument] = []
```
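The new `_get_model_token_tags` tries three sources in order: an OpenAI-style `llm_output['token_usage']`, then the `usage_metadata` that LangChain v1 chat models attach to `AIMessage`/`AIMessageChunk`, and finally a local estimate via `calc_token_usage`. A standalone sketch of the middle branch, with illustrative token counts:

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

# LangChain v1 chat models report usage on the message itself.
msg = AIMessage(
    content="hi",
    usage_metadata={
        "input_tokens": 12,
        "output_tokens": 3,
        "total_tokens": 15,
        "input_token_details": {"cache_read": 4},
        "output_token_details": {"reasoning": 1},
    },
)
response = LLMResult(generations=[[ChatGeneration(message=msg)]])

# The same lookups the handler performs on the first generation batch.
usage = response.generations[0][0].message.usage_metadata
tags = {
    "input_tokens": usage.get("input_tokens", 0),
    "output_tokens": usage.get("output_tokens", 0),
    "input_cached_tokens": usage.get("input_token_details", {}).get("cache_read", 0),
    "reasoning_tokens": usage.get("output_token_details", {}).get("reasoning", 0),
}
tags["tokens"] = tags["input_tokens"] + tags["output_tokens"]
print(tags)
# {'input_tokens': 12, 'output_tokens': 3, 'input_cached_tokens': 4, 'reasoning_tokens': 1, 'tokens': 15}
```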
```diff
@@ -233,18 +304,39 @@ def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[
             flow_span.set_tags({'prompt_version': kwargs['metadata']['lc_hub_commit_hash']})
         flow_span.set_tags({'prompt_provider': 'langsmith'})
 
-    def _new_flow_span(self, span_name: str, span_type: str, **kwargs: Any) -> Span:
+    def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
         span_type = _span_type_mapping(span_type)
+        span_name = node_name
         # set parent span
         parent_span: Span = None
         if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
             parent_span = self.run_map[str(kwargs['parent_run_id'])].span
+        # modify name
+        error_tag = {}
+        try:
+            if self.name_fn:
+                name = self.name_fn(node_name)
+                if name:
+                    span_name = name
+        except Exception as e:
+            error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
         # new span
         flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
         run_id = str(kwargs['run_id'])
         self.run_map[run_id] = Run(run_id, flow_span, span_type)
         # set default tags
         flow_span.set_runtime(RuntimeInfo())
+        # set extra tags
+        flow_span.set_tags(self._tags)  # global tags
+        try:
+            if self.tags_fn:  # add tags fn
+                tags = self.tags_fn(node_name)
+                if isinstance(tags, dict):
+                    flow_span.set_tags(tags)
+        except Exception as e:
+            error_tag = {'error_info': f'tags_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
+        if error_tag:
+            flow_span.set_tags(error_tag)
         return flow_span
 
     def _get_flow_span(self, **kwargs: Any) -> Span:
```
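A design note on the hunk above: exceptions raised by the user-supplied `name_fn`/`tags_fn` are caught and surfaced as `error_info`/`error_trace` tags on the span rather than propagated, so a buggy callback cannot fail the traced request; a falsy `name_fn` result keeps the original node name. Hypothetical callbacks illustrating both paths:

```python
def name_fn(node_name: str) -> str:
    return ""  # falsy: the span keeps the original node name

def tags_fn(node_name: str) -> dict:
    raise RuntimeError("boom")  # caught; recorded as error_info/error_trace tags
```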
