2 changes: 2 additions & 0 deletions lmdeploy/turbomind/turbomind.py
@@ -814,6 +814,8 @@ async def async_stream_infer(self,
while not state or state.status == 0:
await sem.acquire()
state = shared_state.consume()

self.model_inst.clear_grammar()
logger.info(f'[async_stream_infer] session {session_id} done')

def _get_error_output(self, status):
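Note on the Python-side change: clear_grammar() runs only after the streaming loop has fully drained, so a guided request releases its grammar before the same model instance serves the next one. A minimal sketch of the assumed per-request lifecycle (the set_grammar name and the driver function are illustrative assumptions; only clear_grammar is introduced by this diff):

async def guided_stream(model_inst, compiled_grammar, outputs):
    # Sketch only: attach the compiled grammar for this request (assumed setter,
    # bound next to clear_grammar in bind.cpp), stream, then always release it.
    model_inst.set_grammar(compiled_grammar)
    try:
        async for out in outputs:  # drain the streamed outputs, as async_stream_infer does
            yield out
    finally:
        model_inst.clear_grammar()  # new binding: resets the stored CompiledGrammar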
19 changes: 16 additions & 3 deletions src/turbomind/engine/model_request.cc
@@ -13,12 +13,14 @@

namespace turbomind {

ModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):
ModelRequest::ModelRequest(
Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size):
gateway_{gateway},
data_type_{data_type},
session_len_{session_len},
vocab_size_{vocab_size},
hidden_dim_{hidden_dim}
hidden_dim_{hidden_dim},
tp_size_{tp_size}
{
}

@@ -127,8 +129,14 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output
r->output_ids = outputs_->at("output_ids");
r->sequence_length = outputs_->at("sequence_length");

r->matchers.clear();
if (grammar_) {
r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*grammar_);
for (int i = 0; i < tp_size_; ++i) {
r->matchers.push_back(std::make_shared<xgrammar::GrammarMatcher>(*grammar_));
}
}
else {
r->matchers.resize(tp_size_);
}

// Keep a weak reference for canceling the request
@@ -144,4 +152,9 @@ void ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar)
grammar_ = std::make_shared<xgrammar::CompiledGrammar>(grammar);
}

void ModelRequest::clearGrammar()
{
grammar_.reset();
}

} // namespace turbomind
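Note on the matchers vector: xgrammar matchers are stateful (they advance as tokens are accepted) and every tensor-parallel rank runs its own decode-layer stack, so Forward now builds one independent GrammarMatcher per rank from the shared CompiledGrammar; when no grammar is set, the vector is still sized to tp_size_ so that indexing by tp_rank_ stays valid and simply yields an empty matcher. A rough Python analogue of that layout, using xgrammar's Python bindings (the helper name is illustrative, not an lmdeploy API):

from typing import List, Optional

import xgrammar as xgr

def build_matchers(compiled: Optional[xgr.CompiledGrammar], tp_size: int) -> List[Optional[xgr.GrammarMatcher]]:
    """Mirror of the C++ logic above: per-rank matchers, or tp_size empty slots."""
    if compiled is None:
        return [None] * tp_size  # unconstrained request: matchers[tp_rank] stays null
    # One stateful matcher per rank so each rank can mask and accept tokens independently.
    return [xgr.GrammarMatcher(compiled) for _ in range(tp_size)]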
4 changes: 3 additions & 1 deletion src/turbomind/engine/model_request.h
@@ -15,7 +15,7 @@ class ModelRequest {
public:
virtual ~ModelRequest() = default;

ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);
ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size);

// Cancel running request
void Cancel();
@@ -41,6 +41,7 @@

OutputParam Forward(InputParam param, std::function<void()> cb);
void setGrammar(const xgrammar::CompiledGrammar& grammar);
void clearGrammar();

protected:
Gateway* const gateway_;
@@ -50,6 +51,7 @@
const int session_len_;
const int hidden_dim_;
const int vocab_size_;
const int tp_size_;

uint64_t session_id_;

2 changes: 1 addition & 1 deletion src/turbomind/engine/request.h
@@ -154,7 +154,7 @@
kInconsistency = 9, // Inconsistent request parameters, e.g. prefix caching is not allowed in interactive mode
};

std::shared_ptr<xgrammar::GrammarMatcher> matcher;
std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers; // one GrammarMatcher per TP rank (size == tp_size)
};

inline void UpdateState(Request& r, int status, int seq_len)
3 changes: 3 additions & 0 deletions src/turbomind/layers/BaseDynamicDecodeLayer.h
@@ -31,6 +31,7 @@ class BaseDynamicDecodeLayer {
int vocab_size_padded;
cudaStream_t stream;
const cudaDeviceProp* device_prop;
int tp_rank;
};

virtual ~BaseDynamicDecodeLayer() = default;
@@ -42,6 +43,7 @@
vocab_size_padded_ = param.vocab_size_padded;
stream_ = param.stream;
device_prop_ = param.device_prop;
tp_rank_ = param.tp_rank;
};

virtual void Setup(const std::vector<const Request*>& rs, const TensorMap& args) = 0;
@@ -54,6 +56,7 @@
int vocab_size_padded_;
cudaStream_t stream_;
const cudaDeviceProp* device_prop_;
int tp_rank_;
};

} // namespace turbomind
7 changes: 5 additions & 2 deletions src/turbomind/layers/DynamicDecodeLayer.cc
@@ -31,11 +31,14 @@ DynamicDecodeLayer::DynamicDecodeLayer(DataType dtype,
int vocab_size,
int vocab_size_padded,
cudaStream_t stream,
const cudaDeviceProp* device_prop)
const cudaDeviceProp* device_prop,
int tp_rank):
tp_rank_{tp_rank}
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TM_CHECK(dtype == kFloat32);
BaseDynamicDecodeLayer::BaseParam param{max_batch_size, vocab_size, vocab_size_padded, stream, device_prop};
BaseDynamicDecodeLayer::BaseParam param{
max_batch_size, vocab_size, vocab_size_padded, stream, device_prop, tp_rank};
layers_.emplace_back(new LogitsProcessorLayer<float>{param});
layers_.emplace_back(new GuidedDecodeMaskLayer<float>{param});
layers_.emplace_back(new SamplingLayer<float>{param});
4 changes: 3 additions & 1 deletion src/turbomind/layers/DynamicDecodeLayer.h
@@ -33,7 +33,8 @@ class DynamicDecodeLayer {
int vocab_size,
int vocab_size_padded,
cudaStream_t stream,
const cudaDeviceProp* device_prop);
const cudaDeviceProp* device_prop,
int tp_rank);

~DynamicDecodeLayer();

@@ -42,6 +43,7 @@
void Forward(TensorMap& args);

private:
int tp_rank_;
std::vector<std::unique_ptr<BaseDynamicDecodeLayer>> layers_;
};

21 changes: 10 additions & 11 deletions src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
@@ -33,7 +33,7 @@ void GuidedDecodeMaskLayer<T>::Setup(const std::vector<const Request*>& rs, cons
TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
matchers_.clear();
for (const auto& r : rs) {
matchers_.push_back(r->matcher);
matchers_.push_back(r->matchers[tp_rank_]);
}
}

@@ -47,17 +47,16 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)

TM_CHECK(bsz == matchers_.size());

const auto bitmask_size = bitmask_buf_.shape(1);
std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};
const auto bitmask_size = bitmask_buf_.shape(1);
std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};
int32_t* data = bitmask_buf_.data();
size_t total_elements = bitmask_shape[0] * bitmask_shape[1];
bool need_apply = false;
DLTensor bitmask_dltensor{
data, DLDevice{kDLCPU, 0}, bitmask_buf_.ndim(), xgrammar::GetBitmaskDLType(), bitmask_shape.data(), nullptr, 0};

std::fill(data, data + total_elements, 0xffffffff);

DLTensor bitmask_dltensor{bitmask_buf_.data(),
DLDevice{kDLCPU, 0},
bitmask_buf_.ndim(),
xgrammar::GetBitmaskDLType(),
bitmask_shape.data(),
nullptr,
0};
bool need_apply = false;
for (size_t i = 0; i < bsz; ++i) {
const auto& matcher = matchers_[i];
if (matcher) {
4 changes: 3 additions & 1 deletion src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
@@ -15,6 +15,7 @@
*/

#include "src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h"
#include "src/turbomind/core/context.h"

namespace turbomind {

@@ -29,7 +30,7 @@ void GuidedDecodeUpdateLayer<T>::Setup(const std::vector<const Request*>& rs, co
TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
matchers_.clear();
for (const auto& r : rs) {
matchers_.push_back(r->matcher);
matchers_.push_back(r->matchers[tp_rank_]);
}
}

@@ -45,6 +46,7 @@ void GuidedDecodeUpdateLayer<T>::Forward(TensorMap& args)

FT_CHECK(bsz == matchers_.size());
Copy(output_ids.slice(step * bsz, bsz), output_ids_buf);
core::Context::stream().Sync();

for (size_t i = 0; i < bsz; ++i) {
const auto& matcher = matchers_[i];
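Note on the guided-decoding step: GuidedDecodeMaskLayer fills a token bitmask from each rank's matcher before sampling, and GuidedDecodeUpdateLayer feeds the sampled ids back into the same matchers; the added core::Context::stream().Sync() guarantees the device-to-host copy of output_ids has finished before the matchers read it. A minimal Python sketch of the same fill-mask / accept-token cycle with xgrammar's Python helpers (signatures as in recent xgrammar releases; treat the exact names as assumptions):

import torch
import xgrammar as xgr

def guided_step(matcher: xgr.GrammarMatcher, logits: torch.Tensor, vocab_size: int) -> int:
    """One decode step for a single sequence; logits has shape (1, vocab_size)."""
    bitmask = xgr.allocate_token_bitmask(1, vocab_size)  # int32 bitmask buffer, one row per sequence
    matcher.fill_next_token_bitmask(bitmask, 0)  # what GuidedDecodeMaskLayer does per rank
    xgr.apply_token_bitmask_inplace(logits, bitmask)  # disallowed logits -> -inf
    token_id = int(torch.argmax(logits[0]))  # greedy stand-in for SamplingLayer
    matcher.accept_token(token_id)  # what GuidedDecodeUpdateLayer does
    return token_id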
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaV2.cc
@@ -90,7 +90,7 @@ LlamaV2::LlamaV2(DataType dtype,

// using float to avoid data overflow
dynamic_decode_ = std::make_unique<DynamicDecodeLayer>(
kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop);
kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop, engine.mlp_tp_rank);
}

void LlamaV2::updateEmbedding(char* decoder_input,
9 changes: 8 additions & 1 deletion src/turbomind/python/bind.cpp
@@ -498,7 +498,14 @@ PYBIND11_MODULE(_turbomind, m)
model_request->setGrammar(grammar);
},
py::call_guard<py::gil_scoped_release>(),
"grammar"_a);
"grammar"_a)
.def(
"clear_grammar",
[](ModelRequest* model_request) {
TM_LOG_INFO("Release grammar for model_request");
model_request->clearGrammar();
},
py::call_guard<py::gil_scoped_release>());

// transformer model
using ft::LlamaTritonModel;
8 changes: 6 additions & 2 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -454,8 +454,12 @@ std::unique_ptr<ModelRequest> LlamaTritonModel::createModelInstance(int device_i
{
FT_CHECK(engines_[device_id] != nullptr);

return std::make_unique<ModelRequest>(
gateway_.get(), dtype_, engine_param_.session_len, model_param_.vocab_size, model_param_.hidden_units);
return std::make_unique<ModelRequest>(gateway_.get(),
dtype_,
engine_param_.session_len,
model_param_.vocab_size,
model_param_.hidden_units,
comm_size_);
}

void LlamaTritonModel::createSharedWeights(int device_id, int rank)
55 changes: 54 additions & 1 deletion tests/test_lmdeploy/test_grammar.py
@@ -1,11 +1,12 @@
import asyncio
import json
import re

import pytest
from jsonschema import validate

from lmdeploy import pipeline
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, TurbomindEngineConfig

MODEL_IDS = [
'Qwen/Qwen3-0.6B',
@@ -95,3 +96,55 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
assert re.fullmatch(schema, response[0].text)
finally:
pipe.close()


async def collect(*aiters):
results = [[] for _ in range(len(aiters))]

async def drain(idx, aiter):
async for item in aiter:
results[idx].append(item)

await asyncio.gather(*(drain(idx, aiter) for idx, aiter in enumerate(aiters)))

responses = []
for r in results:
resp = Response(text='', input_token_len=0, generate_token_len=0)
responses.append(resp)
for out in r:
resp.text += out.response
resp.input_token_len = out.input_token_len
resp.generate_token_len = out.generate_token_len
resp.finish_reason = out.finish_reason

return responses


@pytest.mark.parametrize('model_id', MODEL_IDS)
@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
def test_mix_guided_matrix(model_id, backend_name, backend_factory):
pipe = pipeline(
model_id,
backend_config=backend_factory(),
log_level='INFO',
)

schema_type = 'json_schema'
response_format = {'type': schema_type}
schema = SCHEMA_MAP[schema_type]
response_format[schema_type] = dict(name='test', schema=schema)

gen_config = GenerationConfig(response_format=response_format)

configs = [None if idx % 3 else gen_config for idx in range(4)]
tasks = [
pipe.generate(messages='Make a self introduction please.', session_id=session_id, gen_config=gen_config)
for session_id, gen_config in enumerate(configs)
]

responses = asyncio.run(collect(*tasks))
for resp, config in zip(responses, configs):
if config is None:
assert '}' not in resp.text
else:
validate(instance=json.loads(resp.text), schema=schema)