Skip to content
21 changes: 14 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ dependencies = [
"pydantic>=2.12.5",
"requests>=2.32.5",
"urllib3>=2.0.0",
"langchain>=0.3.27",
"langchain-core>=1.2.7",
"langchain-openai>=0.3.35",
"langchain-anthropic>=0.3.22",
"langchain-google-genai>=4.2.0",
"langchain-xai>=0.2.5",
"langchain>=1.2.15",
"langchain-core>=1.2.26",
"langchain-openai>=1.1.12",
"langchain-anthropic>=1.4.0",
"langchain-google-genai>=4.2.1",
"langchain-xai>=1.2.2",
"openai>=2.15.0",
"anthropic>=0.76.0",
"google-generativeai>=0.8.6",
Expand All @@ -49,7 +49,14 @@ test = [
]

[tool.ruff]
exclude = [".claude/worktrees"]
exclude = [
".claude/worktrees",
".venv",
"venv",
"env",
"scripts/attestation",
"**/site-packages",
]

[tool.uv]
# Pre-release needed for og-test-v2-x402==0.0.11.dev5
Expand Down
96 changes: 90 additions & 6 deletions tee_gateway/controllers/chat_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ChatCompletionRequestFunctionMessage,
)

from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, SystemMessage

from tee_gateway.tee_manager import get_tee_keys, compute_tee_msg_hash
from tee_gateway.llm_backend import (
Expand Down Expand Up @@ -94,12 +94,45 @@ def _invoke_anthropic_structured(
schema_def = {**schema_def, "title": name}

structured = model.with_structured_output(
schema_def, method="json_schema", strict=strict
schema_def, method="json_schema", strict=strict, include_raw=True
)
result = structured.invoke(langchain_messages)

content_str = json.dumps(result) if isinstance(result, dict) else str(result)
return AIMessage(content=content_str)
# include_raw=True returns {"raw": AIMessage, "parsed": dict, "parsing_error": ...}
if isinstance(result, dict) and result.get("parsing_error"):
raise ValueError(f"Structured output parsing failed: {result['parsing_error']}")
raw_msg = result.get("raw") if isinstance(result, dict) else None
parsed = result.get("parsed") if isinstance(result, dict) else result
content_str = json.dumps(parsed) if isinstance(parsed, dict) else str(parsed)

# Preserve usage_metadata from the raw Anthropic response so the x402
# cost calculator can extract token counts from the final response body.
msg = AIMessage(content=content_str)
if (
raw_msg is not None
and hasattr(raw_msg, "usage_metadata")
and raw_msg.usage_metadata
):
msg.usage_metadata = raw_msg.usage_metadata
return msg


def _messages_contain_json_word(messages: list) -> bool:
"""Return True if any message content contains the word 'json' (case-insensitive).

Handles both plain-string content and list-of-parts content (multimodal messages).
"""
for m in messages:
content = getattr(m, "content", "")
if isinstance(content, str):
if "json" in content.lower():
return True
elif isinstance(content, list):
for part in content:
text = part.get("text", "") if isinstance(part, dict) else str(part)
if "json" in text.lower():
return True
return False


def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
Expand Down Expand Up @@ -147,6 +180,16 @@ def _create_non_streaming_response(chat_request: CreateChatCompletionRequest):
model = model.bind(response_format=rf_dict)

langchain_messages = convert_messages(chat_request.messages)

# OpenAI (and compatible providers) require the word "json" to appear
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we really need this. This class works and does not need to inject anything: https://github.com/OpenGradient/memsync/blob/main/memsync/llms/openai.py#L99

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only the case when we use json_object

E.g. we get this error message

openai.BadRequestError: Error code: 400 - {'error': {'message': "'messages' must contain the word 'json' in some form, to use 'response_format' of type 'json_object'.", 'type': 'invalid_request_error', 'param': 'messages', 'code': None}}

mem0ai/mem0#4248 -- seems like a known requirement.

# somewhere in the messages when response_format.type == "json_object".
# Inject a brief system instruction if none of the messages satisfy this.
if rf_dict and rf_dict.get("type") == "json_object":
if not _messages_contain_json_word(langchain_messages):
langchain_messages = [
SystemMessage(content="Respond in JSON format.")
] + langchain_messages
Comment thread
kylexqian marked this conversation as resolved.

if rf_dict and get_provider_from_model(chat_request.model) == "anthropic":
response = _invoke_anthropic_structured(model, rf_dict, langchain_messages)
else:
Expand Down Expand Up @@ -277,6 +320,22 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
model = model.bind(response_format=rf)

langchain_messages = convert_messages(chat_request.messages)

# OpenAI (and compatible providers) require the word "json" to appear
# somewhere in the messages when response_format.type == "json_object".
# Inject a brief system instruction if none of the messages satisfy this.
# `rf` is defined inside the `if chat_request.response_format` block above;
# guard with the same condition before accessing it.
if (
chat_request.response_format
and anthropic_structured_rf is None
and rf.get("type") == "json_object"
):
if not _messages_contain_json_word(langchain_messages):
langchain_messages = [
SystemMessage(content="Respond in JSON format.")
] + langchain_messages

tee_keys = get_tee_keys()

# For Anthropic structured output, with_structured_output() invokes
Expand All @@ -292,8 +351,21 @@ def _create_streaming_response(chat_request: CreateChatCompletionRequest):
if isinstance(ai_msg.content, str)
else json.dumps(ai_msg.content)
)
# Capture usage now — the streaming loop never runs for this path,
# so final_usage would otherwise stay None and x402 cannot charge.
anthropic_structured_usage: dict[str, int] | None = None

if hasattr(ai_msg, "usage_metadata") and ai_msg.usage_metadata:
um = ai_msg.usage_metadata

anthropic_structured_usage = {
"input_tokens": um.get("input_tokens", 0),
"output_tokens": um.get("output_tokens", 0),
"total_tokens": um.get("total_tokens", 0),
}
else:
anthropic_structured_content = None
anthropic_structured_usage = None

def generate():
full_content = ""
Expand All @@ -303,8 +375,11 @@ def generate():

try:
if anthropic_structured_content is not None:
# Emit the pre-computed structured result as a single chunk
# Emit the pre-computed structured result as a single chunk.
# Seed final_usage from the synchronous invoke so the final
# SSE chunk carries a 'usage' dict for the x402 cost calculator.
full_content = anthropic_structured_content
final_usage = anthropic_structured_usage
data = {
"choices": [
{
Expand Down Expand Up @@ -435,8 +510,17 @@ def generate():
yield f"data: {json.dumps(data)}\n\n"

# --- Usage metadata ---
# Accumulate deltas rather than replacing: Gemini returns cumulative
# usageMetadata on every chunk and LangChain emits deltas via
# subtract_usage(), so input_tokens only appears non-zero in the
# *first* chunk carrying usage. Replacing on each chunk would
# overwrite that value with 0 from all subsequent chunks.
if hasattr(chunk, "usage_metadata") and chunk.usage_metadata:
final_usage = chunk.usage_metadata
if final_usage is None:
final_usage = {}
for k, v in chunk.usage_metadata.items():
if isinstance(v, (int, float)):
final_usage[k] = final_usage.get(k, 0) + v

# Flush buffered tool calls for OpenAI/Anthropic
if buffer_tool_calls and buffered_tool_calls:
Expand Down
Loading
Loading