Skip to content

Commit b95e101

Browse files
Copilot and TomeHirata authored
Fix ContextWindowExceededError after 3 retries in react loop (#9110)
* Initial plan

* Fix ContextWindowExceededError after 3 retries in react loop

  When _call_with_potential_trajectory_truncation exhausts all retry attempts, it now raises a clear ValueError instead of returning None. This prevents the AttributeError: 'NoneType' object has no attribute 'next_thought' that occurred when accessing properties on the None return value.

  The ValueError is caught in forward/aforward and causes the loop to break gracefully, allowing the extract phase to proceed with whatever trajectory was collected.

  Added tests for both sync and async versions of this scenario.

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

* Address code review feedback: Fix comments and error messages

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

* Remove unnecessary try-catch around truncate_trajectory

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

* Remove docstrings from test functions

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

* Combine sync and async context window tests into one test

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

* Verify inputs passed to extract in test

  Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: TomeHirata <33407409+TomeHirata@users.noreply.github.com>
1 parent f4a016e commit b95e101

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

dspy/predict/react.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -153,6 +153,9 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input
153153
except ContextWindowExceededError:
154154
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
155155
trajectory = self.truncate_trajectory(trajectory)
156+
raise ValueError(
157+
"The context window was exceeded even after 3 attempts to truncate the trajectory."
158+
)
156159

157160
async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args):
158161
for _ in range(3):
@@ -164,6 +167,9 @@ async def _async_call_with_potential_trajectory_truncation(self, module, traject
164167
except ContextWindowExceededError:
165168
logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.")
166169
trajectory = self.truncate_trajectory(trajectory)
170+
raise ValueError(
171+
"The context window was exceeded even after 3 attempts to truncate the trajectory."
172+
)
167173

168174
def truncate_trajectory(self, trajectory):
169175
"""Truncates the trajectory so that it fits in the context window.

tests/predict/test_react.py

Lines changed: 48 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -204,6 +204,54 @@ def mock_react(**kwargs):
204204
assert result.output_text == "Final output"
205205

206206

207+
@pytest.mark.asyncio
208+
async def test_context_window_exceeded_after_retries():
209+
def echo(text: str) -> str:
210+
return f"Echoed: {text}"
211+
212+
react = dspy.ReAct("input_text -> output_text", tools=[echo])
213+
214+
def mock_react(**kwargs):
215+
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")
216+
217+
# Test sync version
218+
extract_calls = []
219+
220+
def mock_extract(**kwargs):
221+
extract_calls.append(kwargs)
222+
return dspy.Prediction(output_text="Fallback output")
223+
224+
react.react = mock_react
225+
react.extract = mock_extract
226+
227+
result = react(input_text="test input")
228+
assert result.trajectory == {}
229+
assert result.output_text == "Fallback output"
230+
assert len(extract_calls) == 1
231+
assert extract_calls[0]["input_text"] == "test input"
232+
assert "trajectory" in extract_calls[0]
233+
234+
# Test async version
235+
async_extract_calls = []
236+
237+
async def mock_react_async(**kwargs):
238+
raise litellm.ContextWindowExceededError("Context window exceeded", "dummy_model", "dummy_provider")
239+
240+
async def mock_extract_async(**kwargs):
241+
async_extract_calls.append(kwargs)
242+
return dspy.Prediction(output_text="Fallback output")
243+
244+
react.react.acall = mock_react_async
245+
react.extract.acall = mock_extract_async
246+
247+
result = await react.acall(input_text="test input")
248+
assert result.trajectory == {}
249+
assert result.output_text == "Fallback output"
250+
assert len(async_extract_calls) == 1
251+
assert async_extract_calls[0]["input_text"] == "test input"
252+
assert "trajectory" in async_extract_calls[0]
253+
254+
207255
def test_error_retry():
208256
# --- a tiny tool that always fails -------------------------------------
209257
def foo(a, b):

0 commit comments

Comments
 (0)