
Commit 240f8bd

test: add mixing guided and non-guided tests
1 parent 22f4ea3 commit 240f8bd

File tree

1 file changed: +58 −1 lines changed

tests/test_lmdeploy/test_grammar.py

Lines changed: 58 additions & 1 deletion
@@ -1,11 +1,12 @@
+import asyncio
 import json
 import re
 
 import pytest
 from jsonschema import validate
 
 from lmdeploy import pipeline
-from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
+from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, Response
 
 MODEL_IDS = [
     'Qwen/Qwen3-0.6B',
@@ -95,3 +96,59 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
         assert re.fullmatch(schema, response[0].text)
     finally:
         pipe.close()
+
+
+async def collect(*aiters):
+    results = [[] for _ in range(len(aiters))]
+
+    async def drain(idx, aiter):
+        async for item in aiter:
+            results[idx].append(item)
+
+    await asyncio.gather(*(drain(idx, aiter) for idx, aiter in enumerate(aiters)))
+
+    responses = []
+    for r in results:
+        resp = Response(text='', input_token_len=0, generate_token_len=0)
+        responses.append(resp)
+        for out in r:
+            resp.text += out.response
+            resp.input_token_len = out.input_token_len
+            resp.generate_token_len = out.generate_token_len
+            resp.finish_reason = out.finish_reason
+
+    return responses
+
+@pytest.mark.parametrize('model_id', MODEL_IDS)
+@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
+def test_mix_guided_matrix(model_id, backend_name, backend_factory):
+    pipe = pipeline(
+        model_id,
+        backend_config=backend_factory(),
+        log_level='INFO',
+    )
+
+    schema_type = 'json_schema'
+    response_format = {'type': schema_type}
+    schema = SCHEMA_MAP[schema_type]
+    response_format[schema_type] = dict(name='test', schema=schema)
+
+    gen_config = GenerationConfig(response_format=response_format)
+
+    configs = [None if idx % 3 else gen_config for idx in range(4)]
+    tasks = [
+        pipe.generate(
+            messages='Make a self introduction please.',
+            session_id=session_id,
+            gen_config=gen_config
+        ) for session_id, gen_config in enumerate(configs)
+    ]
+
+    responses = asyncio.run(collect(*tasks))
+    for resp, config in zip(responses, configs):
+        if config is None:
+            assert '}' not in resp.text
+        else:
+            validate(instance=json.loads(resp.text), schema=schema)
+
+
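For context, guided decoding is requested purely through GenerationConfig(response_format=...); sessions that pass no gen_config fall back to ordinary free-form sampling. Below is a minimal synchronous sketch of the same request shape, modeled on the existing test_guided_matrix flow referenced in the hunk header. The schema here is an illustrative placeholder (SCHEMA_MAP is not part of this diff); the model id is taken from MODEL_IDS above.

```python
import json

from jsonschema import validate

from lmdeploy import pipeline
from lmdeploy.messages import GenerationConfig

# Illustrative placeholder schema; the test takes its schemas from SCHEMA_MAP.
schema = {
    'type': 'object',
    'properties': {'name': {'type': 'string'}},
    'required': ['name'],
}
response_format = {
    'type': 'json_schema',
    'json_schema': dict(name='test', schema=schema),
}

pipe = pipeline('Qwen/Qwen3-0.6B')
try:
    prompt = 'Make a self introduction please.'

    # Guided request: decoding is constrained to the JSON schema.
    guided = pipe([prompt], gen_config=GenerationConfig(response_format=response_format))
    validate(instance=json.loads(guided[0].text), schema=schema)

    # Non-guided request: same prompt, default config, free-form text.
    plain = pipe([prompt])
    print(plain[0].text)
finally:
    pipe.close()
```

The new test_mix_guided_matrix exercises the same two request kinds concurrently through the async pipe.generate generators and the collect helper: sessions 0 and 3 are guided, sessions 1 and 2 are plain, so the guided outputs must validate against the schema while the schema constraint must not leak into the plain sessions.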
