Skip to content

Commit c413dbe

Browse files
committed
test: add mixing guided and non-guided tests
1 parent 22f4ea3 commit c413dbe

File tree

1 file changed

+54
-1
lines changed

1 file changed

+54
-1
lines changed

tests/test_lmdeploy/test_grammar.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import asyncio
12
import json
23
import re
34

45
import pytest
56
from jsonschema import validate
67

78
from lmdeploy import pipeline
8-
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
9+
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, TurbomindEngineConfig
910

1011
MODEL_IDS = [
1112
'Qwen/Qwen3-0.6B',
@@ -95,3 +96,55 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
9596
assert re.fullmatch(schema, response[0].text)
9697
finally:
9798
pipe.close()
99+
100+
101+
async def collect(*aiters):
    """Drain several async generators concurrently and merge each stream.

    Every positional argument is an async iterator of streaming outputs.
    All iterators are consumed in parallel; the chunks of each one are
    folded into a single ``Response`` (concatenated text, last-seen token
    counts and finish reason).

    Returns a list of ``Response`` objects, one per input iterator, in the
    same order as the arguments.
    """
    buckets = [[] for _ in aiters]

    async def _drain(slot, stream):
        # Collect every streamed chunk for this iterator, preserving order.
        async for chunk in stream:
            buckets[slot].append(chunk)

    await asyncio.gather(*(_drain(slot, stream) for slot, stream in enumerate(aiters)))

    merged = []
    for chunks in buckets:
        resp = Response(text='', input_token_len=0, generate_token_len=0)
        for chunk in chunks:
            resp.text += chunk.response
            # Token counts / finish reason: keep whatever the last chunk reports.
            resp.input_token_len = chunk.input_token_len
            resp.generate_token_len = chunk.generate_token_len
            resp.finish_reason = chunk.finish_reason
        merged.append(resp)
    return merged
121+
122+
123+
@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
def test_mix_guided_matrix(model_id, backend_name, backend_factory):
    """Run guided and unguided requests concurrently on one pipeline.

    Sessions whose index is a multiple of 3 (0 and 3) use a JSON-schema
    ``GenerationConfig``; the others (1 and 2) run unconstrained. Guided
    outputs must validate against the schema; unguided outputs are expected
    not to contain a closing brace.
    """
    pipe = pipeline(
        model_id,
        backend_config=backend_factory(),
        log_level='INFO',
    )

    # Fix: close the pipeline in a finally block, matching test_guided_matrix,
    # so the engine is released even when an assertion fails.
    try:
        schema_type = 'json_schema'
        response_format = {'type': schema_type}
        schema = SCHEMA_MAP[schema_type]
        response_format[schema_type] = dict(name='test', schema=schema)

        gen_config = GenerationConfig(response_format=response_format)

        # idx % 3 == 0 -> guided; otherwise unguided.
        configs = [None if idx % 3 else gen_config for idx in range(4)]
        # Renamed the comprehension variable to `config` so it no longer
        # shadows the outer `gen_config`.
        tasks = [
            pipe.generate(messages='Make a self introduction please.', session_id=session_id, gen_config=config)
            for session_id, config in enumerate(configs)
        ]

        responses = asyncio.run(collect(*tasks))
        for resp, config in zip(responses, configs):
            if config is None:
                # Free-form prose should not produce a JSON closing brace.
                assert '}' not in resp.text
            else:
                validate(instance=json.loads(resp.text), schema=schema)
    finally:
        pipe.close()

0 commit comments

Comments
 (0)