Skip to content

Commit 02a4169

Browse files
authored
[Tests] Tool call tests for openai/gpt-oss-20b (#26237)
Signed-off-by: Debolina Roy <debroy@redhat.com>
1 parent 7b5575f commit 02a4169

File tree

2 files changed

+360
-0
lines changed

2 files changed

+360
-0
lines changed

requirements/rocm-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ blobfile==3.0.0
4949
# Multi-Modal Models Test
5050
decord==0.6.0
5151
# video processing, required by entrypoints/openai/test_video.py
52+
rapidfuzz==3.12.1
5253

5354
# OpenAI compatibility and testing
5455
gpt-oss==0.0.8
Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
6+
import jsonschema
7+
import openai
8+
import pytest
9+
import pytest_asyncio
10+
from rapidfuzz import fuzz
11+
12+
from ....utils import RemoteOpenAIServer
13+
14+
MODEL_NAME = "openai/gpt-oss-20b"
15+
16+
17+
@pytest.fixture(scope="module")
def server():
    """Launch a module-scoped vLLM server configured for OpenAI tool calling."""
    cli_args = [
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "openai",
    ]
    # One server for the whole module keeps startup cost to a single launch.
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote:
        yield remote
@pytest_asyncio.fixture
async def client(server):
    """Async fixture providing an OpenAI-compatible vLLM client."""
    async with server.get_async_client() as async_client:
        yield async_client
# ==========================================================
# Tool Definitions
# ==========================================================
_CALCULATOR_TOOL = {
    "type": "function",
    "function": {
        "name": "calculator",
        "description": "Performs basic arithmetic calculations.",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": (
                        "Arithmetic expression to evaluate, e.g. '123 + 456'."
                    ),
                }
            },
            "required": ["expression"],
        },
    },
}

_GET_TIME_TOOL = {
    "type": "function",
    "function": {
        "name": "get_time",
        "description": "Retrieves the current local time for a given city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "City name, e.g. 'New York'.",
                }
            },
            "required": ["city"],
        },
    },
}

# Tool list passed verbatim to the chat.completions API in every test.
TOOLS = [_CALCULATOR_TOOL, _GET_TIME_TOOL]
# ==========================================================
# Message Examples
# ==========================================================
MESSAGES_CALC = [
    {"role": "user", "content": "Calculate 123 + 456 using the calculator."}
]

MESSAGES_GET_TIME = [
    {"role": "user", "content": "What is the current time in New York?"}
]

MESSAGES_MULTIPLE_CALLS = [
    {
        "role": "system",
        "content": (
            "You can call multiple tools. "
            # BUG FIX: the original adjacent literals had no separating space
            # and concatenated to "...tool_calls arraycontaining each...".
            "When using more than one, return single JSON object with tool_calls array "
            "containing each tool call with its function name and arguments. "
            "Do not output multiple JSON objects separately."
        ),
    },
    {
        "role": "user",
        "content": "First, calculate 7 * 8 using the calculator. "
        "Then, use get_time to tell me the current time in New York.",
    },
]

MESSAGES_INVALID_CALL = [
    {
        "role": "user",
        "content": "Can you help with something, "
        "but don’t actually perform any calculation?",
    }
]

# Expected outputs (function names and reference argument payloads)
FUNC_CALC = "calculator"
FUNC_ARGS_CALC = '{"expression":"123 + 456"}'

FUNC_TIME = "get_time"
FUNC_ARGS_TIME = '{"city": "New York"}'
# ==========================================================
# Utility to extract reasoning and tool calls
# ==========================================================
def extract_reasoning_and_calls(chunks: list) -> tuple[str, list[str], list[str]]:
    """Accumulate reasoning text and tool-call data from streaming chunks.

    Returns a tuple of (reasoning_text, arguments_per_call, function_names),
    with calls ordered by their streaming ``index``.
    """
    reasoning_parts: list[str] = []
    calls: dict[int, dict[str, str]] = {}

    for chunk in chunks:
        delta = getattr(chunk.choices[0], "delta", None)
        if not delta:
            continue

        text = getattr(delta, "reasoning_content", None)
        if text:
            reasoning_parts.append(text)

        for call in getattr(delta, "tool_calls", None) or []:
            # Deltas for the same call share an index; merge them into one slot.
            slot = calls.setdefault(
                getattr(call, "index", 0), {"name": "", "arguments": ""}
            )
            fn = getattr(call, "function", None)
            if not fn:
                continue
            if getattr(fn, "name", None):
                slot["name"] = fn.name
            if getattr(fn, "arguments", None):
                # Argument JSON arrives in fragments; concatenate in order.
                slot["arguments"] += fn.arguments

    ordered = [calls[idx] for idx in sorted(calls)]
    return (
        "".join(reasoning_parts),
        [entry["arguments"] for entry in ordered],
        [entry["name"] for entry in ordered],
    )
# ==========================================================
# Test Scenarios
# ==========================================================
@pytest.mark.asyncio
async def test_calculator_tool_call_and_argument_accuracy(client: openai.AsyncOpenAI):
    """Verify calculator tool call is made and arguments are accurate."""

    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    message = response.choices[0].message
    calls = getattr(message, "tool_calls", [])
    assert calls, "No tool calls detected"

    # Locate the calculator call among whatever tool calls the model made.
    calc_call = None
    for candidate in calls:
        if candidate.function.name == FUNC_CALC:
            calc_call = candidate
            break
    assert calc_call, "Calculator function not called"

    raw_args = calc_call.function.arguments
    assert raw_args, "Calculator arguments missing"
    assert "123" in raw_args and "456" in raw_args, (
        f"Expected values not in raw arguments: {raw_args}"
    )

    try:
        parsed = json.loads(raw_args)
    except json.JSONDecodeError:
        pytest.fail(f"Invalid JSON in calculator arguments: {raw_args}")

    # Fuzzy match tolerates harmless whitespace/formatting differences.
    expected_expr = "123 + 456"
    actual_expr = parsed.get("expression", "")
    similarity = fuzz.ratio(actual_expr, expected_expr)
    assert similarity > 90, (
        f"Expression mismatch: expected '{expected_expr}' "
        f"got '{actual_expr}' (similarity={similarity}%)"
    )
@pytest.mark.asyncio
async def test_streaming_tool_call_get_time_with_reasoning(client: openai.AsyncOpenAI):
    """Verify streamed reasoning and tool call behavior for get_time."""

    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_GET_TIME,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    collected = [chunk async for chunk in stream]
    reasoning, arguments, function_names = extract_reasoning_and_calls(collected)

    assert FUNC_TIME in function_names, "get_time function not called"
    assert any("New York" in arg for arg in arguments), (
        f"Expected get_time arguments for New York not found in {arguments}"
    )
    assert len(reasoning) > 0, "Expected reasoning content missing"

    # Sanity-check that the reasoning actually concerns the user's question.
    relevant_keywords = ("New York", "time", "current")
    assert any(keyword in reasoning for keyword in relevant_keywords), (
        f"Reasoning is not relevant to the request: {reasoning}"
    )
@pytest.mark.asyncio
async def test_streaming_multiple_tools(client: openai.AsyncOpenAI):
    """Test streamed multi-tool response with reasoning."""
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = [chunk async for chunk in stream]
    reasoning, arguments, function_names = extract_reasoning_and_calls(chunks)

    # BUG FIX: these assertions were previously wrapped in
    # `try/except AssertionError` that only printed the error, so the test
    # could never fail. Let assertion failures propagate to pytest.
    assert FUNC_CALC in function_names, (
        f"Calculator tool missing — found {function_names}"
    )
    assert FUNC_TIME in function_names, (
        f"Time tool missing — found {function_names}"
    )
    assert len(reasoning) > 0, "Expected reasoning content in streamed response"
@pytest.mark.asyncio
async def test_invalid_tool_call(client: openai.AsyncOpenAI):
    """
    Verify that ambiguous instructions that should not trigger a tool
    do not produce any tool calls.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_INVALID_CALL,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    msg = response.choices[0].message
    assert msg is not None, "Expected message in response"
    assert hasattr(msg, "content"), "Expected 'content' field in message"

    # A plain-text answer is expected; any tool invocation is a failure here.
    tool_calls = getattr(msg, "tool_calls", [])
    assert not tool_calls, (
        f"Model unexpectedly attempted a tool call on invalid input: {tool_calls}"
    )
@pytest.mark.asyncio
async def test_tool_call_with_temperature(client: openai.AsyncOpenAI):
    """
    Verify model produces valid tool or text output
    under non-deterministic sampling.
    """
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_CALC,
        tools=TOOLS,
        temperature=0.7,
        stream=False,
    )

    message = response.choices[0].message
    assert message is not None, "Expected non-empty message in response"
    # With sampling enabled, either form of answer is acceptable — just not
    # an entirely empty response.
    assert message.tool_calls or message.content, (
        "Response missing both text and tool calls"
    )

    print(f"\nTool calls: {message.tool_calls}")
    print(f"Text: {message.content}")
@pytest.mark.asyncio
async def test_tool_response_schema_accuracy(client: openai.AsyncOpenAI):
    """Validate that tool call arguments adhere to their declared JSON schema."""
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES_MULTIPLE_CALLS,
        tools=TOOLS,
        temperature=0.0,
    )

    calls = response.choices[0].message.tool_calls
    assert calls, "No tool calls produced"

    # Index declared parameter schemas by function name (names in TOOLS are
    # unique, so this matches the original first-hit linear scan).
    schemas: dict[object, object] = {}
    for tool_entry in TOOLS:
        function_def = tool_entry.get("function")
        if function_def and isinstance(function_def, dict):
            schemas[function_def.get("name")] = function_def.get("parameters")

    for call in calls:
        func_name = call.function.name
        args = json.loads(call.function.arguments)

        schema = schemas.get(func_name)
        assert schema is not None, f"No matching tool schema found for {func_name}"

        jsonschema.validate(instance=args, schema=schema)
@pytest.mark.asyncio
async def test_semantic_consistency_with_temperature(client: openai.AsyncOpenAI):
    """Test that temperature variation doesn't cause contradictory reasoning."""
    responses = []
    for temp in (0.0, 0.5, 1.0):
        resp = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=MESSAGES_CALC,
            tools=TOOLS,
            temperature=temp,
        )
        responses.append((resp.choices[0].message.content or "").strip())

    # NOTE(review): the T=1.0 response is collected but never asserted on —
    # confirm whether a low/high comparison was intended.
    # Compare fuzzy similarity between low- and mid-temperature outputs
    low_mid_sim = fuzz.ratio(responses[0], responses[1])
    assert low_mid_sim > 60, (
        f"Semantic drift too large between T=0.0 and T=0.5 ({low_mid_sim}%)"
    )
0 commit comments

Comments
 (0)