|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import re |
3 | 4 |
|
4 | 5 | import pytest |
5 | 6 | from jsonschema import validate |
6 | 7 |
|
7 | 8 | from lmdeploy import pipeline |
8 | | -from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig |
| 9 | +from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig, Response |
9 | 10 |
|
10 | 11 | MODEL_IDS = [ |
11 | 12 | 'Qwen/Qwen3-0.6B', |
@@ -95,3 +96,59 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type): |
95 | 96 | assert re.fullmatch(schema, response[0].text) |
96 | 97 | finally: |
97 | 98 | pipe.close() |
| 99 | + |
| 100 | + |
async def collect(*aiters):
    """Drain several async iterators concurrently and merge each into one ``Response``.

    Args:
        *aiters: Async iterables. Every item they yield must expose the
            attributes ``response``, ``input_token_len``,
            ``generate_token_len`` and ``finish_reason``.

    Returns:
        A list with one ``Response`` per input iterator, in argument order.
        ``text`` is the concatenation of every chunk's ``response``; the token
        counts and ``finish_reason`` are taken from the last chunk. An
        iterator that yields nothing produces an empty-text ``Response`` with
        zero token counts and an untouched ``finish_reason``.
    """

    async def drain(aiter):
        # Materialize one stream; each task owns its list, so no shared state.
        return [item async for item in aiter]

    # gather() returns results in the order the awaitables were passed, which
    # keeps every merged response aligned with its source iterator.
    streams = await asyncio.gather(*(drain(aiter) for aiter in aiters))

    responses = []
    for outs in streams:
        resp = Response(
            text=''.join(out.response for out in outs),
            input_token_len=0,
            generate_token_len=0,
        )
        if outs:
            # Only the final chunk's bookkeeping survives, exactly as if each
            # chunk had overwritten the previous one in turn.
            last = outs[-1]
            resp.input_token_len = last.input_token_len
            resp.generate_token_len = last.generate_token_len
            resp.finish_reason = last.finish_reason
        responses.append(resp)

    return responses
| 121 | + |
@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
def test_mix_guided_matrix(model_id, backend_name, backend_factory):
    """Run guided and unguided requests concurrently on one pipeline.

    Four requests are issued with session ids 0-3. Requests 0 and 3 carry a
    ``json_schema`` response format; requests 1 and 2 pass ``gen_config=None``.
    Guided outputs must parse as JSON and validate against the schema, while
    unguided outputs must not contain a ``}`` character.
    """
    pipe = pipeline(
        model_id,
        backend_config=backend_factory(),
        log_level='INFO',
    )

    # Close the pipeline even when an assertion fails, matching the cleanup
    # done by the other guided-decoding test in this module.
    try:
        schema_type = 'json_schema'
        schema = SCHEMA_MAP[schema_type]
        response_format = {
            'type': schema_type,
            schema_type: dict(name='test', schema=schema),
        }
        gen_config = GenerationConfig(response_format=response_format)

        # idx % 3 == 0 -> guided (indices 0 and 3); everything else unguided.
        configs = [None if idx % 3 else gen_config for idx in range(4)]

        # NOTE(review): pipe.generate is consumed with ``async for`` by
        # collect(), so it presumably returns an async generator that does no
        # work until iterated -- confirm against the pipeline API.
        tasks = [
            pipe.generate(
                messages='Make a self introduction please.',
                session_id=session_id,
                gen_config=config,
            ) for session_id, config in enumerate(configs)
        ]

        responses = asyncio.run(collect(*tasks))
        for resp, config in zip(responses, configs):
            if config is None:
                assert '}' not in resp.text
            else:
                validate(instance=json.loads(resp.text), schema=schema)
    finally:
        pipe.close()
| 153 | + |
| 154 | + |
0 commit comments