|
7 | 7 | from typing import Any, Literal, cast |
8 | 8 |
|
9 | 9 | from langchain_core.messages import ( |
10 | | - AIMessage, |
11 | 10 | AnyMessage, |
12 | 11 | MessageLikeRepresentation, |
13 | 12 | RemoveMessage, |
|
# Number of trailing messages preserved after summarization by default.
_DEFAULT_MESSAGES_TO_KEEP = 20
# Default token budget for trimming — presumably applied when preparing the
# summarization input; usage not visible in this chunk, confirm against callers.
_DEFAULT_TRIM_TOKEN_LIMIT = 4000
# Fallback message count — NOTE(review): looks like the limit used when a
# token-based budget cannot be determined; verify against callers.
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
59 | | -_SEARCH_RANGE_FOR_TOOL_PAIRS = 5 |
60 | 58 |
|
61 | 59 | ContextFraction = tuple[Literal["fraction"], float] |
62 | 60 | """Fraction of model's maximum input tokens. |
@@ -397,11 +395,8 @@ def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None: |
397 | 395 | return 0 |
398 | 396 | cutoff_candidate = len(messages) - 1 |
399 | 397 |
|
400 | | - for i in range(cutoff_candidate, -1, -1): |
401 | | - if self._is_safe_cutoff_point(messages, i): |
402 | | - return i |
403 | | - |
404 | | - return 0 |
| 398 | + # Advance past any ToolMessages to avoid splitting AI/Tool pairs |
| 399 | + return self._find_safe_cutoff_point(messages, cutoff_candidate) |
405 | 400 |
|
406 | 401 | def _get_profile_limits(self) -> int | None: |
407 | 402 | """Retrieve max input token limit from the model profile.""" |
@@ -463,67 +458,26 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) - |
463 | 458 |
|
464 | 459 | Returns the index where messages can be safely cut without separating |
465 | 460 | related AI and Tool messages. Returns `0` if no safe cutoff is found. |
| 461 | +
|
| 462 | + This is aggressive with summarization - if the target cutoff lands in the |
| 463 | + middle of tool messages, we advance past all of them (summarizing more). |
466 | 464 | """ |
467 | 465 | if len(messages) <= messages_to_keep: |
468 | 466 | return 0 |
469 | 467 |
|
470 | 468 | target_cutoff = len(messages) - messages_to_keep |
| 469 | + return self._find_safe_cutoff_point(messages, target_cutoff) |
471 | 470 |
|
472 | | - for i in range(target_cutoff, -1, -1): |
473 | | - if self._is_safe_cutoff_point(messages, i): |
474 | | - return i |
475 | | - |
476 | | - return 0 |
477 | | - |
478 | | - def _is_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> bool: |
479 | | - """Check if cutting at index would separate AI/Tool message pairs.""" |
480 | | - if cutoff_index >= len(messages): |
481 | | - return True |
482 | | - |
483 | | - search_start = max(0, cutoff_index - _SEARCH_RANGE_FOR_TOOL_PAIRS) |
484 | | - search_end = min(len(messages), cutoff_index + _SEARCH_RANGE_FOR_TOOL_PAIRS) |
485 | | - |
486 | | - for i in range(search_start, search_end): |
487 | | - if not self._has_tool_calls(messages[i]): |
488 | | - continue |
489 | | - |
490 | | - tool_call_ids = self._extract_tool_call_ids(cast("AIMessage", messages[i])) |
491 | | - if self._cutoff_separates_tool_pair(messages, i, cutoff_index, tool_call_ids): |
492 | | - return False |
493 | | - |
494 | | - return True |
| 471 | + def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int: |
| 472 | + """Find a safe cutoff point that doesn't split AI/Tool message pairs. |
495 | 473 |
|
496 | | - def _has_tool_calls(self, message: AnyMessage) -> bool: |
497 | | - """Check if message is an AI message with tool calls.""" |
498 | | - return ( |
499 | | - isinstance(message, AIMessage) and hasattr(message, "tool_calls") and message.tool_calls # type: ignore[return-value] |
500 | | - ) |
501 | | - |
502 | | - def _extract_tool_call_ids(self, ai_message: AIMessage) -> set[str]: |
503 | | - """Extract tool call IDs from an AI message.""" |
504 | | - tool_call_ids = set() |
505 | | - for tc in ai_message.tool_calls: |
506 | | - call_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) |
507 | | - if call_id is not None: |
508 | | - tool_call_ids.add(call_id) |
509 | | - return tool_call_ids |
510 | | - |
511 | | - def _cutoff_separates_tool_pair( |
512 | | - self, |
513 | | - messages: list[AnyMessage], |
514 | | - ai_message_index: int, |
515 | | - cutoff_index: int, |
516 | | - tool_call_ids: set[str], |
517 | | - ) -> bool: |
518 | | - """Check if cutoff separates an AI message from its corresponding tool messages.""" |
519 | | - for j in range(ai_message_index + 1, len(messages)): |
520 | | - message = messages[j] |
521 | | - if isinstance(message, ToolMessage) and message.tool_call_id in tool_call_ids: |
522 | | - ai_before_cutoff = ai_message_index < cutoff_index |
523 | | - tool_before_cutoff = j < cutoff_index |
524 | | - if ai_before_cutoff != tool_before_cutoff: |
525 | | - return True |
526 | | - return False |
| 474 | + If the message at cutoff_index is a ToolMessage, advance until we find |
| 475 | + a non-ToolMessage. This ensures we never cut in the middle of parallel |
| 476 | + tool call responses. |
| 477 | + """ |
| 478 | + while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage): |
| 479 | + cutoff_index += 1 |
| 480 | + return cutoff_index |
527 | 481 |
|
528 | 482 | def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: |
529 | 483 | """Generate summary for the given messages.""" |
|
0 commit comments