Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)
raise ValueError(
"The context window was exceeded even after 3 attempts to truncate the trajectory."
)

async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
for _ in range(3):
Expand All @@ -164,6 +167,9 @@ async def _async_call_with_potential_trajectory_truncation(self, module, traject
except ContextWindowExceededError:
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
trajectory = self.truncate_trajectory(trajectory)
raise ValueError(
"The context window was exceeded even after 3 attempts to truncate the trajectory."
)

def truncate_trajectory(self, trajectory):
"""Truncates the trajectory so that it fits in the context window.
Expand Down
33 changes: 33 additions & 0 deletions tests/predict/test_react.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,39 @@ def mock_react(**kwargs):
assert result.output_text == "Final output"


@pytest.mark.asyncio
async def test_context_window_exceeded_after_retries():
def echo(text: str) -> str:
return f"Echoed: {text}"

react = dspy.ReAct("input_text -> output_text", tools=[echo])

def mock_react(**kwargs):
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")

react.react = mock_react
react.extract = lambda **kwargs: dspy.Prediction(output_text="Fallback output")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot can we verify what input is passed to extract?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added verification that extract receives the correct inputs (input_text and trajectory). The test now captures the kwargs passed to extract and asserts they contain the expected values. Fixed in 18e999f.


# Test sync version
result = react(input_text="test input")
assert result.trajectory == {}
assert result.output_text == "Fallback output"

# Test async version
async def mock_react_async(**kwargs):
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")

async def mock_extract_async(**kwargs):
return dspy.Prediction(output_text="Fallback output")

react.react.acall = mock_react_async
react.extract.acall = mock_extract_async

result = await react.acall(input_text="test input")
assert result.trajectory == {}
assert result.output_text == "Fallback output"


def test_error_retry():
# --- a tiny tool that always fails -------------------------------------
def foo(a, b):
Expand Down