Merged
5 changes: 5 additions & 0 deletions CHANGLOG.md
@@ -1,3 +1,8 @@
## [0.1.20] - 2025-11-10
### Added
- langchain callback: support LangChain V1
- langchain callback: support setting the span name and tags

## [0.1.19] - 2025-11-10
### Fixed
- fix baggage escaping issue
146 changes: 119 additions & 27 deletions cozeloop/integration/langchain/trace_callback.py
@@ -5,12 +5,13 @@
import json
import time
import traceback
from typing import List, Dict, Union, Any, Optional
from typing import List, Dict, Union, Any, Optional, Callable, Protocol

import pydantic
from pydantic import Field, BaseModel
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentFinish, AgentAction, LLMResult
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.outputs import LLMResult, ChatGeneration
from langchain_core.agents import AgentFinish, AgentAction
from langchain_core.prompt_values import PromptValue, ChatPromptValue
from langchain_core.messages import BaseMessage, AIMessageChunk, AIMessage
from langchain_core.prompts import AIMessagePromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
@@ -28,24 +29,40 @@

class LoopTracer:
@classmethod
def get_callback_handler(cls, client: Client = None):
def get_callback_handler(
cls,
client: Client = None,
modify_name_fn: Optional[Callable[[str], str]] = None,
add_tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
tags: Dict[str, Any] = None,
):
"""
Do not hold it for a long time, get a new callback_handler for each request.
modify_name_fn: modify name function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is span name.
add_tags_fn: add tags function, input is node name(if you use langgraph, like add_node(node_name, node_func), it is node name), output is tags dict.
"""
global _trace_callback_client
if client:
_trace_callback_client = client
else:
_trace_callback_client = get_default_client()

return LoopTraceCallbackHandler()
return LoopTraceCallbackHandler(modify_name_fn, add_tags_fn, tags)
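For reviewers, a minimal usage sketch of the new parameters (the naming scheme and tag values below are hypothetical, and a default cozeloop client is assumed to be configured; otherwise pass `client=...` explicitly):

```python
from cozeloop.integration.langchain.trace_callback import LoopTracer

# Hypothetical names/tags for illustration only.
handler = LoopTracer.get_callback_handler(
    modify_name_fn=lambda node_name: f"graph.{node_name}",  # node name -> span name
    add_tags_fn=lambda node_name: {"node": node_name},      # node name -> per-span tags
    tags={"env": "staging"},                                # static tags set on every span
)

# Get a fresh handler per request, then pass it as a LangChain callback, e.g.:
# result = graph.invoke(inputs, config={"callbacks": [handler]})
```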


class LoopTraceCallbackHandler(BaseCallbackHandler):
def __init__(self):
def __init__(
self,
name_fn: Optional[Callable[[str], str]] = None,
tags_fn: Optional[Callable[[str], Dict[str, Any]]] = None,
tags: Dict[str, Any] = None,
):
super().__init__()
self._space_id = _trace_callback_client.workspace_id
self.run_map: Dict[str, Run] = {}
self.name_fn = name_fn
self.tags_fn = tags_fn
self._tags = tags if tags else {}

def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> Any:
span_tags = {}
@@ -97,32 +114,26 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
try:
# set output span_tag
flow_span.set_tags({'output': ModelTraceOutput(response.generations).to_json()})
# set model tags
tags = self._get_model_tags(response, **kwargs)
if tags:
self._set_span_tags(flow_span, tags, need_convert_tag_value=False)
except Exception as e:
flow_span.set_error(e)
# calculate token usage, and set span_tag
if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
self._set_span_tags(flow_span, response.llm_output['token_usage'], need_convert_tag_value=False)
else:
try:
run_info = self.run_map[str(kwargs['run_id'])]
if run_info is not None and run_info.model_meta is not None:
model_name = run_info.model_meta.model_name
input_messages = run_info.model_meta.message
flow_span.set_input_tokens(calc_token_usage(input_messages, model_name))
flow_span.set_output_tokens(calc_token_usage(response, model_name))
except Exception as e:
flow_span.set_error(e)
# finish flow_span
flow_span.finish()
self._end_flow_span(flow_span)

def on_chain_start(self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any) -> Any:
flow_span = None
try:
if kwargs.get('run_type', '') == 'prompt' or kwargs.get('name', '') == 'ChatPromptTemplate':
flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
flow_span = self._new_flow_span(kwargs['name'], 'prompt', **kwargs)
self._on_prompt_start(flow_span, serialized, inputs, **kwargs)
else:
flow_span = self._new_flow_span(kwargs['name'], kwargs['name'], **kwargs)
span_type = 'chain'
if kwargs['name'] == 'LangGraph':  # LangGraph gets the 'agent' span_type so trajectory evaluation can aggregate spans into an agent
span_type = 'agent'
flow_span = self._new_flow_span(kwargs['name'], span_type, **kwargs)
flow_span.set_tags({'input': _convert_2_json(inputs)})
except Exception as e:
if flow_span is not None:
@@ -141,7 +152,7 @@ def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> Any:
flow_span.set_tags({'output': _convert_2_json(outputs)})
except Exception as e:
flow_span.set_error(e)
flow_span.finish()
self._end_flow_span(flow_span)

def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
flow_span = self._get_flow_span(**kwargs)
@@ -150,7 +161,7 @@
flow_span = self._new_flow_span(span_name, 'chain_error', **kwargs)
flow_span.set_error(error)
flow_span.set_tags({'error_trace': traceback.format_exc()})
flow_span.finish()
self._end_flow_span(flow_span)

def on_tool_start(
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
@@ -166,7 +177,7 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
flow_span.set_tags({'output': _convert_2_json(output)})
except Exception as e:
flow_span.set_error(e)
flow_span.finish()
self._end_flow_span(flow_span)

def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
@@ -177,7 +188,7 @@ def on_tool_error(
flow_span = self._new_flow_span(span_name, 'tool_error', **kwargs)
flow_span.set_error(error)
flow_span.set_tags({'error_trace': traceback.format_exc()})
flow_span.finish()
self._end_flow_span(flow_span)

def on_text(self, text: str, **kwargs: Any) -> Any:
"""Run on arbitrary text."""
@@ -188,6 +199,67 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
return

def _end_flow_span(self, span: Span):
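# Note: these static tags are applied at finish time, so on key collisions they land after any tags_fn output set at span start (see the review thread below).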
span.set_tags(self._tags)
span.finish()

def _get_model_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
return self._get_model_token_tags(response, **kwargs)

def _get_model_token_tags(self, response: LLMResult, **kwargs: Any) -> Dict[str, Any]:
result = {}
is_get_from_langchain = False
if response.llm_output is not None and 'token_usage' in response.llm_output and response.llm_output['token_usage']:
is_get_from_langchain = True
result['input_tokens'] = response.llm_output.get('token_usage', {}).get('prompt_tokens', 0)
result['output_tokens'] = response.llm_output.get('token_usage', {}).get('completion_tokens', 0)
result['tokens'] = result['input_tokens'] + result['output_tokens']
reasoning_tokens = response.llm_output.get('token_usage', {}).get('completion_tokens_details', {}).get('reasoning_tokens', 0)
if reasoning_tokens:
result['reasoning_tokens'] = reasoning_tokens
input_cached_tokens = response.llm_output.get('token_usage', {}).get('prompt_tokens_details', {}).get('cached_tokens', 0)
if input_cached_tokens:
result['input_cached_tokens'] = input_cached_tokens
elif response.generations is not None and len(response.generations) > 0 and response.generations[0] is not None:
for i, generation in enumerate(response.generations[0]):
if isinstance(generation, ChatGeneration) and isinstance(generation.message, (AIMessageChunk, AIMessage)) and generation.message.usage_metadata:
is_get_from_langchain = True
result['input_tokens'] = generation.message.usage_metadata.get('input_tokens', 0)
result['output_tokens'] = generation.message.usage_metadata.get('output_tokens', 0)
result['tokens'] = result['input_tokens'] + result['output_tokens']
if generation.message.usage_metadata.get('output_token_details', {}):
reasoning_tokens = generation.message.usage_metadata.get('output_token_details', {}).get('reasoning', 0)
if reasoning_tokens:
result['reasoning_tokens'] = reasoning_tokens
if generation.message.usage_metadata.get('input_token_details', {}):
input_read_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_read', 0)
if input_read_cached_tokens:
result['input_cached_tokens'] = input_read_cached_tokens
input_creation_cached_tokens = generation.message.usage_metadata.get('input_token_details', {}).get('cache_creation', 0)
if input_creation_cached_tokens:
result['input_creation_cached_tokens'] = input_creation_cached_tokens
if is_get_from_langchain:
return result
else:
try:
run_info = self.run_map[str(kwargs['run_id'])]
if run_info is not None and run_info.model_meta is not None:
model_name = run_info.model_meta.model_name
input_messages = run_info.model_meta.message
token_usage = {
'input_tokens': calc_token_usage(input_messages, model_name),
'output_tokens': calc_token_usage(response, model_name),
'tokens': 0
}
token_usage['tokens'] = token_usage['input_tokens'] + token_usage['output_tokens']
return token_usage
except Exception as e:
span_tags = {'error_info': repr(e), 'error_trace': traceback.format_exc()}
return span_tags
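The LangChain-side branch above consumes `generation.message.usage_metadata`; a sketch of the shape it expects, with keys from langchain_core's `UsageMetadata` and illustrative values:

```python
# Illustrative only; keys follow langchain_core's UsageMetadata TypedDict,
# values are made up.
usage_metadata = {
    "input_tokens": 120,
    "output_tokens": 45,
    "total_tokens": 165,
    "input_token_details": {"cache_read": 30, "cache_creation": 0},
    "output_token_details": {"reasoning": 12},
}
```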

def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: Union[Dict[str, Any], str], **kwargs: Any) -> None:
# get inputs
params: List[Argument] = []
@@ -233,18 +305,38 @@ def _on_prompt_start(self, flow_span, serialized: Dict[str, Any], inputs: (Dict[
flow_span.set_tags({'prompt_version': kwargs['metadata']['lc_hub_commit_hash']})
flow_span.set_tags({'prompt_provider': 'langsmith'})

def _new_flow_span(self, span_name: str, span_type: str, **kwargs: Any) -> Span:
def _new_flow_span(self, node_name: str, span_type: str, **kwargs: Any) -> Span:
span_type = _span_type_mapping(span_type)
span_name = node_name
# set parent span
parent_span: Span = None
if 'parent_run_id' in kwargs and kwargs['parent_run_id'] is not None and str(kwargs['parent_run_id']) in self.run_map:
parent_span = self.run_map[str(kwargs['parent_run_id'])].span
# modify name
error_tag = {}
try:
if self.name_fn:
name = self.name_fn(node_name)
if name:
span_name = name
except Exception as e:
error_tag = {'error_info': f'name_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
# new span
flow_span = _trace_callback_client.start_span(span_name, span_type, child_of=parent_span)
run_id = str(kwargs['run_id'])
self.run_map[run_id] = Run(run_id, flow_span, span_type)
# set default tags
flow_span.set_runtime(RuntimeInfo())
# set extra tags
try:
if self.tags_fn:
tags = self.tags_fn(node_name)
Collaborator: Shouldn't tags_fn take priority over tags? That seems more reasonable.

Collaborator (Author): Yeah, that makes sense. I'll change it.
if isinstance(tags, dict):
flow_span.set_tags(tags)
except Exception as e:
error_tag = {'error_info': f'tags_fn error {repr(e)}', 'error_trace': traceback.format_exc()}
if error_tag:
flow_span.set_tags(error_tag)
return flow_span
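On the precedence question in the review thread above: static `tags` are applied in `_end_flow_span` at finish time, so they currently land after any `tags_fn` output set here at span start. A sketch of one way to flip that, assuming the agreed behavior is that `tags_fn` wins (the `_merged_tags` helper is hypothetical, not part of this diff):

```python
# Hypothetical helper: merge static tags first so per-node tags from
# tags_fn win on key collisions.
def _merged_tags(self, node_tags: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(self._tags)       # static tags from get_callback_handler(tags=...)
    merged.update(node_tags or {})  # tags_fn output takes precedence
    return merged
```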

def _get_flow_span(self, **kwargs: Any) -> Span: