Skip to content

Commit 3030ffc

Browse files
fix: simplify summarization cutoff logic (#34195)
This PR changes how we find the cutoff for summarization, summarizing content more eagerly if the initial cutoff point isn't safe (i.e., it would break apart AI + tool message pairs).

This new algorithm is quite simple: it looks at the initial cutoff point and, if that point isn't safe, moves forward through the message list until it finds the first non-tool message.

For example (H = human message, AI = AI message, TM = tool message):

```
H
AI
TM
--- theoretical cutoff based on keep=('messages', 3)
TM
AI
TM
```

```
H
AI
TM
TM
--- actual cutoff, more aggressive summarization
AI
TM
```
1 parent 1ad9de4 commit 3030ffc

File tree

2 files changed

+127
-198
lines changed

2 files changed

+127
-198
lines changed

libs/langchain_v1/langchain/agents/middleware/summarization.py

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import Any, Literal, cast
88

99
from langchain_core.messages import (
10-
AIMessage,
1110
AnyMessage,
1211
MessageLikeRepresentation,
1312
RemoveMessage,
@@ -56,7 +55,6 @@
5655
_DEFAULT_MESSAGES_TO_KEEP = 20
5756
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
5857
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
59-
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5
6058

6159
ContextFraction = tuple[Literal["fraction"], float]
6260
"""Fraction of model's maximum input tokens.
@@ -397,11 +395,8 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
397395
return 0
398396
cutoff_candidate = len(messages) - 1
399397

400-
for i in range(cutoff_candidate, -1, -1):
401-
if self._is_safe_cutoff_point(messages, i):
402-
return i
403-
404-
return 0
398+
# Advance past any ToolMessages to avoid splitting AI/Tool pairs
399+
return self._find_safe_cutoff_point(messages, cutoff_candidate)
405400

406401
def _get_profile_limits(self) -> int | None:
407402
"""Retrieve max input token limit from the model profile."""
@@ -463,67 +458,26 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -
463458
464459
Returns the index where messages can be safely cut without separating
465460
related AI and Tool messages. Returns `0` if no safe cutoff is found.
461+
462+
This is aggressive with summarization - if the target cutoff lands in the
463+
middle of tool messages, we advance past all of them (summarizing more).
466464
"""
467465
if len(messages) <= messages_to_keep:
468466
return 0
469467

470468
target_cutoff = len(messages) - messages_to_keep
469+
return self._find_safe_cutoff_point(messages, target_cutoff)
471470

472-
for i in range(target_cutoff, -1, -1):
473-
if self._is_safe_cutoff_point(messages, i):
474-
return i
475-
476-
return 0
477-
478-
def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool:
479-
"""Check if cutting at index would separate AI/Tool message pairs."""
480-
if cutoff_index >= len(messages):
481-
return True
482-
483-
search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS)
484-
search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS)
485-
486-
for i in range(search_start, search_end):
487-
if not self._has_tool_calls(messages[i]):
488-
continue
489-
490-
tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i]))
491-
if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids):
492-
return False
493-
494-
return True
471+
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
472+
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.
495473
496-
def _has_tool_calls(self, message: AnyMessage) -> bool:
497-
"""Check if message is an AI message with tool calls."""
498-
return (
499-
isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value]
500-
)
501-
502-
def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]:
503-
"""Extract tool call IDs from an AI message."""
504-
tool_call_ids = set()
505-
for tc in ai_message.tool_calls:
506-
call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None)
507-
if call_id is not None:
508-
tool_call_ids.add(call_id)
509-
return tool_call_ids
510-
511-
def _cutoff_separates_tool_pair(
512-
self,
513-
messages: list[AnyMessage],
514-
ai_message_index: int,
515-
cutoff_index: int,
516-
tool_call_ids: set[str],
517-
) -> bool:
518-
"""Check if cutoff separates an AI message from its corresponding tool messages."""
519-
for j in range(ai_message_index + 1, len(messages)):
520-
message = messages[j]
521-
if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids:
522-
ai_before_cutoff = ai_message_index < cutoff_index
523-
tool_before_cutoff = j < cutoff_index
524-
if ai_before_cutoff != tool_before_cutoff:
525-
return True
526-
return False
474+
If the message at cutoff_index is a ToolMessage, advance until we find
475+
a non-ToolMessage. This ensures we never cut in the middle of parallel
476+
tool call responses.
477+
"""
478+
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
479+
cutoff_index += 1
480+
return cutoff_index
527481

528482
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
529483
"""Generate summary for the given messages."""

0 commit comments

Comments
 (0)