Skip to content

Commit 1107feb

Browse files
authored
Merge pull request #17803 from BerriAI/litellm_preserve_system_instructions
fix: Preserve systemInstruction for Vertex AI generate content request
2 parents 2ea855d + 9344d29 commit 1107feb

File tree

5 files changed

+97
-4
lines changed

5 files changed

+97
-4
lines changed

litellm/google_genai/main.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,15 @@ def setup_generate_content_call(
164164
model=model,
165165
)
166166
)
167+
# Extract systemInstruction from kwargs to pass to transform
168+
system_instruction = kwargs.get("systemInstruction") or kwargs.get("system_instruction")
167169
request_body = (
168170
generate_content_provider_config.transform_generate_content_request(
169171
model=model,
170172
contents=contents,
171173
tools=tools,
172174
generate_content_config_dict=generate_content_config_dict,
175+
system_instruction=system_instruction,
173176
)
174177
)
175178

@@ -311,6 +314,9 @@ def generate_content(
311314
**kwargs,
312315
)
313316

317+
# Extract systemInstruction from kwargs to pass to handler
318+
system_instruction = kwargs.get("systemInstruction") or kwargs.get("system_instruction")
319+
314320
# Check if we should use the adapter (when provider config is None)
315321
if setup_result.generate_content_provider_config is None:
316322
# Use the adapter to convert to completion format
@@ -340,6 +346,7 @@ def generate_content(
340346
_is_async=_is_async,
341347
client=kwargs.get("client"),
342348
litellm_metadata=kwargs.get("litellm_metadata", {}),
349+
system_instruction=system_instruction,
343350
)
344351

345352
return response
@@ -395,6 +402,9 @@ async def agenerate_content_stream(
395402
**kwargs,
396403
)
397404

405+
# Extract systemInstruction from kwargs to pass to handler
406+
system_instruction = kwargs.get("systemInstruction") or kwargs.get("system_instruction")
407+
398408
# Check if we should use the adapter (when provider config is None)
399409
if setup_result.generate_content_provider_config is None:
400410
# Use the adapter to convert to completion format
@@ -428,6 +438,7 @@ async def agenerate_content_stream(
428438
client=kwargs.get("client"),
429439
stream=True,
430440
litellm_metadata=kwargs.get("litellm_metadata", {}),
441+
system_instruction=system_instruction,
431442
)
432443

433444
except Exception as e:

litellm/llms/base_llm/google_genai/transformation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def transform_generate_content_request(
149149
contents: GenerateContentContentListUnionDict,
150150
tools: Optional[ToolConfigDict],
151151
generate_content_config_dict: Dict,
152+
system_instruction: Optional[Any] = None,
152153
) -> dict:
153154
"""
154155
Transform the request parameters for the generate content API.
@@ -157,9 +158,8 @@ def transform_generate_content_request(
157158
model: The model name
158159
contents: Input contents
159160
tools: Tools
160-
generate_content_request_params: Request parameters
161-
litellm_params: LiteLLM parameters
162-
headers: Request headers
161+
generate_content_config_dict: Generation config parameters
162+
system_instruction: Optional system instruction
163163
164164
Returns:
165165
Transformed request data

litellm/llms/custom_httpx/llm_http_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7311,6 +7311,7 @@ def generate_content_handler(
73117311
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
73127312
stream: bool = False,
73137313
litellm_metadata: Optional[Dict[str, Any]] = None,
7314+
system_instruction: Optional[Any] = None,
73147315
) -> Any:
73157316
"""
73167317
Handles Google GenAI generate content requests.
@@ -7336,6 +7337,7 @@ def generate_content_handler(
73367337
client=client if isinstance(client, AsyncHTTPHandler) else None,
73377338
stream=stream,
73387339
litellm_metadata=litellm_metadata,
7340+
system_instruction=system_instruction,
73397341
)
73407342

73417343
if client is None or not isinstance(client, HTTPHandler):
@@ -7365,6 +7367,7 @@ def generate_content_handler(
73657367
contents=contents,
73667368
tools=tools,
73677369
generate_content_config_dict=generate_content_config_dict,
7370+
system_instruction=system_instruction,
73687371
)
73697372

73707373
if extra_body:
@@ -7435,6 +7438,7 @@ async def async_generate_content_handler(
74357438
client: Optional[AsyncHTTPHandler] = None,
74367439
stream: bool = False,
74377440
litellm_metadata: Optional[Dict[str, Any]] = None,
7441+
system_instruction: Optional[Any] = None,
74387442
) -> Any:
74397443
"""
74407444
Async version of the generate content handler.
@@ -7472,6 +7476,7 @@ async def async_generate_content_handler(
74727476
contents=contents,
74737477
tools=tools,
74747478
generate_content_config_dict=generate_content_config_dict,
7479+
system_instruction=system_instruction,
74757480
)
74767481

74777482
if extra_body:

litellm/llms/gemini/google_genai/transformation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def transform_generate_content_request(
272272
contents: GenerateContentContentListUnionDict,
273273
tools: Optional[ToolConfigDict],
274274
generate_content_config_dict: Dict,
275+
system_instruction: Optional[Any] = None,
275276
) -> dict:
276277
from litellm.types.google_genai.main import (
277278
GenerateContentConfigDict,

tests/test_litellm/proxy/google_endpoints/test_google_api_endpoints.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,80 @@ async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config,
233233
assert called_data["litellm_metadata"]["user_api_key_user_id"] == "test-user-id"
234234
assert called_data["litellm_metadata"]["user_api_key_team_id"] == "test-team-id"
235235
# Verify stream is set to True
236-
assert called_data["stream"] is True
236+
assert called_data["stream"] is True
237+
238+
239+
def test_google_generate_content_with_system_instruction():
240+
"""
241+
Test that systemInstruction is correctly passed through from the endpoint to the router.
242+
243+
This test verifies the fix for systemInstruction being dropped when forwarding
244+
requests to Vertex AI through the Google GenAI endpoint.
245+
"""
246+
try:
247+
from fastapi import FastAPI
248+
from fastapi.testclient import TestClient
249+
250+
from litellm.proxy.google_endpoints.endpoints import router as google_router
251+
except ImportError as e:
252+
pytest.skip(f"Skipping test due to missing dependency: {e}")
253+
254+
# Create a FastAPI app and include the router
255+
app = FastAPI()
256+
app.include_router(google_router)
257+
258+
# Create a test client
259+
client = TestClient(app)
260+
261+
# Mock all required proxy server dependencies
262+
with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \
263+
patch("litellm.proxy.proxy_server.general_settings", {}), \
264+
patch("litellm.proxy.proxy_server.proxy_config") as mock_proxy_config, \
265+
patch("litellm.proxy.proxy_server.version", "1.0.0"), \
266+
patch("litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request") as mock_add_data:
267+
268+
mock_router.agenerate_content = AsyncMock(return_value={"test": "response"})
269+
270+
# Mock add_litellm_data_to_request to pass through data unchanged
271+
async def mock_add_litellm_data(data, request, user_api_key_dict, proxy_config, general_settings, version):
272+
return data
273+
274+
mock_add_data.side_effect = mock_add_litellm_data
275+
276+
# Define the systemInstruction to test
277+
system_instruction = {
278+
"parts": [{"text": "Your name is Doodle."}]
279+
}
280+
281+
# Send a request with systemInstruction
282+
response = client.post(
283+
"/v1beta/models/gemini-2.5-pro:generateContent",
284+
json={
285+
"systemInstruction": system_instruction,
286+
"contents": [
287+
{
288+
"parts": [{"text": "What is your name?"}],
289+
"role": "user"
290+
}
291+
]
292+
},
293+
headers={"Authorization": "Bearer sk-test-key"}
294+
)
295+
296+
# Verify the response
297+
assert response.status_code == 200
298+
299+
# Verify that agenerate_content was called
300+
mock_router.agenerate_content.assert_called_once()
301+
call_args = mock_router.agenerate_content.call_args
302+
called_data = call_args[1]
303+
304+
# Verify that systemInstruction is present in the call arguments
305+
assert "systemInstruction" in called_data
306+
assert called_data["systemInstruction"] == system_instruction
307+
assert called_data["systemInstruction"]["parts"][0]["text"] == "Your name is Doodle."
308+
309+
# Verify contents are also present
310+
assert "contents" in called_data
311+
assert len(called_data["contents"]) == 1
312+
assert called_data["contents"][0]["role"] == "user"

0 commit comments

Comments
 (0)