@@ -24,7 +24,7 @@
from typing_extensions import Protocol

from langchain.agents.middleware.types import (
-AgentMiddleware,
+ContextAwareAgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -182,7 +182,7 @@ def _build_cleared_tool_input_message(
)


-class ContextEditingMiddleware(AgentMiddleware):
+class ContextEditingMiddleware(ContextAwareAgentMiddleware):
"""Automatically prune tool results to manage context size.

The middleware applies a sequence of edits when the total input token count exceeds
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/file_search.py
@@ -17,7 +17,7 @@

from langchain_core.tools import tool

-from langchain.agents.middleware.types import AgentMiddleware
+from langchain.agents.middleware.types import ContextAwareAgentMiddleware


def _expand_include_patterns(pattern: str) -> list[str] | None:
@@ -84,7 +84,7 @@ def _match_include_pattern(basename: str, pattern: str) -> bool:
return any(fnmatch.fnmatch(basename, candidate) for candidate in expanded)


-class FilesystemFileSearchMiddleware(AgentMiddleware):
+class FilesystemFileSearchMiddleware(ContextAwareAgentMiddleware):
"""Provides Glob and Grep search over filesystem files.

This middleware adds two tools that search through local filesystem:
@@ -7,7 +7,7 @@
from langgraph.types import interrupt
from typing_extensions import NotRequired, TypedDict

-from langchain.agents.middleware.types import AgentMiddleware, AgentState, ContextT, StateT
+from langchain.agents.middleware.types import AgentState, ContextAwareAgentMiddleware, ContextT


class Action(TypedDict):
@@ -156,7 +156,7 @@ def format_tool_description(
"""JSON schema for the args associated with the action, if edits are allowed."""


-class HumanInTheLoopMiddleware(AgentMiddleware[StateT, ContextT]):
+class HumanInTheLoopMiddleware(ContextAwareAgentMiddleware):
"""Human in the loop middleware."""

def __init__(
19 changes: 14 additions & 5 deletions libs/langchain_v1/langchain/agents/middleware/model_call_limit.py
@@ -6,6 +6,7 @@

from langchain_core.messages import AIMessage
from langgraph.channels.untracked_value import UntrackedValue
+from langgraph.typing import ContextT
from typing_extensions import NotRequired

from langchain.agents.middleware.types import (
@@ -86,7 +87,7 @@ def __init__(
super().__init__(msg)


-class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
+class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, ContextT]):
"""Tracks model call counts and enforces limits.
This middleware monitors the number of model calls made during agent execution
@@ -157,7 +158,7 @@ def __init__(
self.exit_behavior = exit_behavior

@hook_config(can_jump_to=["end"])
-def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+def before_model(
+    self,
+    state: ModelCallLimitState,
+    runtime: Runtime[ContextT],  # noqa: ARG002
+) -> dict[str, Any] | None:
"""Check model call limits before making a model call.
Args:
@@ -203,7 +208,7 @@ def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str
async def abefore_model(
self,
state: ModelCallLimitState,
-runtime: Runtime,
+runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async check model call limits before making a model call.
@@ -222,7 +227,7 @@ async def abefore_model(
"""
return self.before_model(state, runtime)

-def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+def after_model(
+    self,
+    state: ModelCallLimitState,
+    runtime: Runtime[ContextT],  # noqa: ARG002
+) -> dict[str, Any] | None:
"""Increment model call counts after a model call.
Args:
@@ -240,7 +249,7 @@ def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str,
async def aafter_model(
self,
state: ModelCallLimitState,
-runtime: Runtime,
+runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Async increment model call counts after a model call.
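To illustrate what the `Runtime[ContextT]` signatures above enable, here is a minimal sketch of a downstream subclass that pins the context type. It is an illustration only, not code from this diff: the `TenantContext` schema, the `TenantAwareCallLimit` class, and the re-applied `hook_config` decorator are all assumptions.

```python
from dataclasses import dataclass
from typing import Any

from langgraph.runtime import Runtime

from langchain.agents.middleware.model_call_limit import (
    ModelCallLimitMiddleware,
    ModelCallLimitState,
)
from langchain.agents.middleware.types import hook_config


@dataclass
class TenantContext:
    """Hypothetical runtime context schema, used only for this sketch."""

    tenant_id: str


class TenantAwareCallLimit(ModelCallLimitMiddleware[TenantContext]):
    """Sketch: pin ContextT so the inherited hooks see a typed Runtime."""

    # Re-applying hook_config is an assumption, made so the override keeps the
    # base hook's can_jump_to metadata.
    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: ModelCallLimitState,
        runtime: Runtime[TenantContext],
    ) -> dict[str, Any] | None:
        # With ModelCallLimitMiddleware now generic over ContextT, runtime.context
        # is typed as TenantContext here rather than Any.
        _ = runtime.context.tenant_id
        return super().before_model(state, runtime)
```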
@@ -5,7 +5,7 @@
from typing import TYPE_CHECKING

from langchain.agents.middleware.types import (
-AgentMiddleware,
+ContextAwareAgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -18,7 +18,7 @@
from langchain_core.language_models.chat_models import BaseChatModel


-class ModelFallbackMiddleware(AgentMiddleware):
+class ModelFallbackMiddleware(ContextAwareAgentMiddleware):
"""Automatic fallback to alternative models on errors.

Retries failed model calls with alternative models in sequence until
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/model_retry.py
@@ -15,15 +15,15 @@
should_retry_exception,
validate_retry_params,
)
-from langchain.agents.middleware.types import AgentMiddleware, ModelResponse
+from langchain.agents.middleware.types import ContextAwareAgentMiddleware, ModelResponse

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable

from langchain.agents.middleware.types import ModelRequest


-class ModelRetryMiddleware(AgentMiddleware):
+class ModelRetryMiddleware(ContextAwareAgentMiddleware):
"""Middleware that automatically retries failed model calls with configurable backoff.

Supports retrying on specific exceptions and exponential backoff.
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/pii.py
@@ -18,15 +18,15 @@
detect_mac_address,
detect_url,
)
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+from langchain.agents.middleware.types import AgentState, ContextAwareAgentMiddleware, hook_config

if TYPE_CHECKING:
from collections.abc import Callable

from langgraph.runtime import Runtime


-class PIIMiddleware(AgentMiddleware):
+class PIIMiddleware(ContextAwareAgentMiddleware):
"""Detect and handle Personally Identifiable Information (PII) in conversations.

This middleware detects common PII types and applies configurable strategies
@@ -19,7 +19,7 @@
)
from langgraph.runtime import Runtime

-from langchain.agents.middleware.types import AgentMiddleware, AgentState
+from langchain.agents.middleware.types import AgentState, ContextAwareAgentMiddleware
from langchain.chat_models import BaseChatModel, init_chat_model

TokenCounter = Callable[[Iterable[MessageLikeRepresentation]], int]
@@ -127,7 +127,7 @@ def _get_approximate_token_counter(model: BaseChatModel) -> TokenCounter:
return count_tokens_approximately


-class SummarizationMiddleware(AgentMiddleware):
+class SummarizationMiddleware(ContextAwareAgentMiddleware):
"""Summarizes conversation history when token limits are approached.

This middleware monitors message token counts and automatically summarizes older
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/todo.py
@@ -14,8 +14,8 @@
from typing_extensions import NotRequired, TypedDict

from langchain.agents.middleware.types import (
-AgentMiddleware,
AgentState,
+ContextAwareAgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -127,7 +127,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
)


-class TodoListMiddleware(AgentMiddleware):
+class TodoListMiddleware(ContextAwareAgentMiddleware):
"""Middleware that provides todo list management capabilities to agents.

This middleware adds a `write_todos` tool that allows agents to create and manage
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/tool_retry.py
@@ -16,7 +16,7 @@
should_retry_exception,
validate_retry_params,
)
-from langchain.agents.middleware.types import AgentMiddleware
+from langchain.agents.middleware.types import ContextAwareAgentMiddleware

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@@ -27,7 +27,7 @@
from langchain.tools import BaseTool


-class ToolRetryMiddleware(AgentMiddleware):
+class ToolRetryMiddleware(ContextAwareAgentMiddleware):
"""Middleware that automatically retries failed tool calls with configurable backoff.
Supports retrying on specific exceptions and exponential backoff.
@@ -17,7 +17,7 @@
from typing_extensions import TypedDict

from langchain.agents.middleware.types import (
-AgentMiddleware,
+ContextAwareAgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
@@ -85,7 +85,7 @@ def _render_tool_list(tools: list[BaseTool]) -> str:
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)


-class LLMToolSelectorMiddleware(AgentMiddleware):
+class LLMToolSelectorMiddleware(ContextAwareAgentMiddleware):
"""Uses an LLM to select relevant tools before calling the main model.

When an agent has many tools available, this middleware filters them down
12 changes: 11 additions & 1 deletion libs/langchain_v1/langchain/agents/middleware/types.py
@@ -2,7 +2,7 @@

from __future__ import annotations

-from collections.abc import Awaitable, Callable
+from collections.abc import Callable
from dataclasses import dataclass, field, replace
from inspect import iscoroutinefunction
from typing import (
@@ -47,6 +47,7 @@
__all__ = [
"AgentMiddleware",
"AgentState",
+"ContextAwareAgentMiddleware",
"ContextT",
"ModelRequest",
"ModelResponse",
@@ -688,6 +689,15 @@ async def awrap_tool_call(self, request, handler):
raise NotImplementedError(msg)


+class ContextAwareAgentMiddleware(AgentMiddleware[StateT, ContextT]):
+    """Base middleware class for agents that access runtime context.
+
+    This specialization of `AgentMiddleware` parameterizes the runtime context as
+    `ContextT`, providing a common base for middleware that needs to read or modify
+    agent state while also using contextual information from the LangGraph runtime.
+    """


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with `AgentState` and `Runtime` as arguments."""

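As a rough sketch of how the new `ContextAwareAgentMiddleware` base added above might be used, the example below parameterizes it with a concrete state and context type. The `DeploymentContext` schema and `EnvironmentGuardMiddleware` are hypothetical names for illustration, not part of this PR.

```python
from dataclasses import dataclass
from typing import Any

from langgraph.runtime import Runtime

from langchain.agents.middleware.types import (
    AgentState,
    ContextAwareAgentMiddleware,
)


@dataclass
class DeploymentContext:
    """Hypothetical context schema, used only for this sketch."""

    environment: str = "production"


class EnvironmentGuardMiddleware(
    ContextAwareAgentMiddleware[AgentState, DeploymentContext]
):
    """Sketch: read a typed runtime context before each model call."""

    def before_model(
        self,
        state: AgentState,
        runtime: Runtime[DeploymentContext],
    ) -> dict[str, Any] | None:
        # Because the base class carries both StateT and ContextT, the attribute
        # access below is statically typed instead of falling back to Any.
        if runtime.context.environment != "production":
            print("agent running outside production")  # placeholder for the sketch
        return None
```

Assuming `AgentState` satisfies the `StateT` bound, a type checker can then verify `runtime.context` access inside every hook of the subclass.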
14 changes: 7 additions & 7 deletions libs/langchain_v1/uv.lock

Some generated files are not rendered by default.