From adb0148cccd1548013a94f149f58a798a060f613 Mon Sep 17 00:00:00 2001
From: Zhongbo Tian
Date: Fri, 28 Nov 2025 18:05:22 +0800
Subject: [PATCH 1/5] fix: fix guided decoding state corruption in turbomind when tp>1

---
 src/turbomind/engine/model_request.cc | 14 +++++++++++---
 src/turbomind/engine/model_request.h | 3 ++-
 src/turbomind/engine/request.h | 2 +-
 src/turbomind/layers/BaseDynamicDecodeLayer.h | 3 +++
 src/turbomind/layers/DynamicDecodeLayer.cc | 7 +++++--
 src/turbomind/layers/DynamicDecodeLayer.h | 4 +++-
 .../sampling_layers/GuidedDecodeMaskLayer.cc | 2 +-
 .../sampling_layers/GuidedDecodeUpdateLayer.cc | 2 +-
 src/turbomind/models/llama/LlamaV2.cc | 2 +-
 .../triton_backend/llama/LlamaTritonModel.cc | 8 ++++++--
 10 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc
index ba7ebe321f..805b6dac5d 100644
--- a/src/turbomind/engine/model_request.cc
+++ b/src/turbomind/engine/model_request.cc
@@ -13,12 +13,14 @@

 namespace turbomind {

-ModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):
+ModelRequest::ModelRequest(
+    Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size):
     gateway_{gateway},
     data_type_{data_type},
     session_len_{session_len},
     vocab_size_{vocab_size},
-    hidden_dim_{hidden_dim}
+    hidden_dim_{hidden_dim},
+    tp_size_{tp_size}
 {
 }

@@ -127,8 +129,14 @@ auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> Output
     r->output_ids      = outputs_->at("output_ids");
     r->sequence_length = outputs_->at("sequence_length");

+    r->matchers.clear();
     if (grammar_) {
-        r->matcher = std::make_shared<xgrammar::GrammarMatcher>(*grammar_);
+        for (int i = 0; i < tp_size_; ++i) {
+            r->matchers.push_back(std::make_shared<xgrammar::GrammarMatcher>(*grammar_));
+        }
+    }
+    else {
+        r->matchers.resize(tp_size_);
     }

     // Keep a weak reference for canceling the request
diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h
index 7582163095..b05a980312 100644
--- a/src/turbomind/engine/model_request.h
+++ b/src/turbomind/engine/model_request.h
@@ -15,7 +15,7 @@ class ModelRequest {
 public:
     virtual ~ModelRequest() = default;

-    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);
+    ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim, int tp_size);

     // Cancel running request
     void Cancel();
@@ -50,6 +50,7 @@ class ModelRequest {
     const int session_len_;
     const int hidden_dim_;
     const int vocab_size_;
+    const int tp_size_;

     uint64_t session_id_;

diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h
index aa50a48100..a5661bf82c 100644
--- a/src/turbomind/engine/request.h
+++ b/src/turbomind/engine/request.h
@@ -154,7 +154,7 @@ struct Request {
         kInconsistency = 9,  // Inconsistent request parameters, e.g. prefix caching is not allowed in interactive mode
     };

-    std::shared_ptr<xgrammar::GrammarMatcher> matcher;
+    std::vector<std::shared_ptr<xgrammar::GrammarMatcher>> matchers;  // GrammarMatchers for different threads (tp_size)
 };

 inline void UpdateState(Request& r, int status, int seq_len)
diff --git a/src/turbomind/layers/BaseDynamicDecodeLayer.h b/src/turbomind/layers/BaseDynamicDecodeLayer.h
index a3e14407ff..239b8bc78d 100644
--- a/src/turbomind/layers/BaseDynamicDecodeLayer.h
+++ b/src/turbomind/layers/BaseDynamicDecodeLayer.h
@@ -31,6 +31,7 @@ class BaseDynamicDecodeLayer {
         int                   vocab_size_padded;
         cudaStream_t          stream;
         const cudaDeviceProp* device_prop;
+        int                   tp_rank;
     };

     virtual ~BaseDynamicDecodeLayer() = default;
@@ -42,6 +43,7 @@ class BaseDynamicDecodeLayer {
         vocab_size_padded_ = param.vocab_size_padded;
         stream_            = param.stream;
         device_prop_       = param.device_prop;
+        tp_rank_           = param.tp_rank;
     };

     virtual void Setup(const std::vector<const Request*>& rs, const TensorMap& args) = 0;
@@ -54,6 +56,7 @@ class BaseDynamicDecodeLayer {
     int                   vocab_size_padded_;
     cudaStream_t          stream_;
     const cudaDeviceProp* device_prop_;
+    int                   tp_rank_;
 };

 }  // namespace turbomind
diff --git a/src/turbomind/layers/DynamicDecodeLayer.cc b/src/turbomind/layers/DynamicDecodeLayer.cc
index 5a66bf1fb6..11a54cb6aa 100644
--- a/src/turbomind/layers/DynamicDecodeLayer.cc
+++ b/src/turbomind/layers/DynamicDecodeLayer.cc
@@ -31,11 +31,14 @@ DynamicDecodeLayer::DynamicDecodeLayer(DataType dtype,
                                        int                   vocab_size,
                                        int                   vocab_size_padded,
                                        cudaStream_t          stream,
-                                       const cudaDeviceProp* device_prop)
+                                       const cudaDeviceProp* device_prop,
+                                       int                   tp_rank):
+    tp_rank_{tp_rank}
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
     TM_CHECK(dtype == kFloat32);
-    BaseDynamicDecodeLayer::BaseParam param{max_batch_size, vocab_size, vocab_size_padded, stream, device_prop};
+    BaseDynamicDecodeLayer::BaseParam param{
+        max_batch_size, vocab_size, vocab_size_padded, stream, device_prop, tp_rank};
     layers_.emplace_back(new LogitsProcessorLayer<float>{param});
     layers_.emplace_back(new GuidedDecodeMaskLayer<float>{param});
     layers_.emplace_back(new SamplingLayer<float>{param});
diff --git a/src/turbomind/layers/DynamicDecodeLayer.h b/src/turbomind/layers/DynamicDecodeLayer.h
index c527ff8e0f..233f7f6f9b 100644
--- a/src/turbomind/layers/DynamicDecodeLayer.h
+++ b/src/turbomind/layers/DynamicDecodeLayer.h
@@ -33,7 +33,8 @@ class DynamicDecodeLayer {
                        int                   vocab_size,
                        int                   vocab_size_padded,
                        cudaStream_t          stream,
-                       const cudaDeviceProp* device_prop);
+                       const cudaDeviceProp* device_prop,
+                       int                   tp_rank);

     ~DynamicDecodeLayer();

@@ -42,6 +43,7 @@ class DynamicDecodeLayer {
     void Forward(TensorMap& args);

 private:
+    int tp_rank_;
     std::vector<std::unique_ptr<BaseDynamicDecodeLayer>> layers_;
 };

diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
index 2262992902..2371fe56c0 100644
--- a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
+++ b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
@@ -33,7 +33,7 @@ void GuidedDecodeMaskLayer<T>::Setup(const std::vector<const Request*>& rs, cons
     TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     matchers_.clear();
     for (const auto& r : rs) {
-        matchers_.push_back(r->matcher);
+        matchers_.push_back(r->matchers[tp_rank_]);
     }
 }

diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
index 653a8874d8..cb31d306c7 100644
--- a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
+++ b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
@@ -29,7 +29,7 @@ void GuidedDecodeUpdateLayer<T>::Setup(const std::vector<const Request*>& rs, co
     TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
     matchers_.clear();
     for (const auto& r : rs) {
-        matchers_.push_back(r->matcher);
+        matchers_.push_back(r->matchers[tp_rank_]);
     }
 }

diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index 68185eac38..729f9b380a 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -90,7 +90,7 @@ LlamaV2::LlamaV2(DataType dtype,

     // using float to avoid data overflow
     dynamic_decode_ = std::make_unique<DynamicDecodeLayer>(
-        kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop);
+        kFloat32, max_batch_size, vocab_size_, vocab_size_padded_, stream_, &ctx.device_prop, engine.mlp_tp_rank);
 }

 void LlamaV2::updateEmbedding(char* decoder_input,
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 853b0a96d8..841e538fb4 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -454,8 +454,12 @@ std::unique_ptr<ModelRequest> LlamaTritonModel::createModelInstance(int device_i
 {
     FT_CHECK(engines_[device_id] != nullptr);

-    return std::make_unique<ModelRequest>(
-        gateway_.get(), dtype_, engine_param_.session_len, model_param_.vocab_size, model_param_.hidden_units);
+    return std::make_unique<ModelRequest>(gateway_.get(),
+                                          dtype_,
+                                          engine_param_.session_len,
+                                          model_param_.vocab_size,
+                                          model_param_.hidden_units,
+                                          comm_size_);
 }

 void LlamaTritonModel::createSharedWeights(int device_id, int rank)

From 98a911700ec8803bbf9f8e15d45cdc4204484c32 Mon Sep 17 00:00:00 2001
From: Zhongbo Tian
Date: Thu, 4 Dec 2025 18:28:53 +0800
Subject: [PATCH 2/5] fix: fix a potential synchronization bug when updating state

---
 src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
index cb31d306c7..25c48e3f23 100644
--- a/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
+++ b/src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.cc
@@ -15,6 +15,7 @@
  */

 #include "src/turbomind/layers/sampling_layers/GuidedDecodeUpdateLayer.h"
+#include "src/turbomind/core/context.h"

 namespace turbomind {

@@ -45,6 +46,7 @@ void GuidedDecodeUpdateLayer<T>::Forward(TensorMap& args)
     FT_CHECK(bsz == matchers_.size());

     Copy(output_ids.slice(step * bsz, bsz), output_ids_buf);
+    core::Context::stream().Sync();

     for (size_t i = 0; i < bsz; ++i) {
         const auto& matcher = matchers_[i];

From fbbbb4f52f4f9fcc45d05f50e757c42a8a0a4ec8 Mon Sep 17 00:00:00 2001
From: Zhongbo Tian
Date: Wed, 10 Dec 2025 15:03:36 +0800
Subject: [PATCH 3/5] fix: fix uninitialized bitmask buffer when mixing guided requests and non-guided requests

---
 .../sampling_layers/GuidedDecodeMaskLayer.cc | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
index 2371fe56c0..d98c6661cf 100644
--- a/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
+++ b/src/turbomind/layers/sampling_layers/GuidedDecodeMaskLayer.cc
@@ -47,17 +47,16 @@ void GuidedDecodeMaskLayer<T>::Forward(TensorMap& args)

     TM_CHECK(bsz == matchers_.size());

-    const auto           bitmask_size  = bitmask_buf_.shape(1);
-    std::vector<int64_t> bitmask_shape = {bsz, bitmask_size};
+    const auto           bitmask_size   = bitmask_buf_.shape(1);
+    std::vector<int64_t> bitmask_shape  = {bsz, bitmask_size};
+    int32_t*             data           = bitmask_buf_.data();
+    size_t               total_elements = bitmask_shape[0] * bitmask_shape[1];
+    bool                 need_apply     = false;
+    DLTensor             bitmask_dltensor{
+        data, DLDevice{kDLCPU, 0}, bitmask_buf_.ndim(), xgrammar::GetBitmaskDLType(), bitmask_shape.data(), nullptr, 0};
+
+    std::fill(data, data + total_elements, 0xffffffff);

-    DLTensor bitmask_dltensor{bitmask_buf_.data(),
-                              DLDevice{kDLCPU, 0},
-                              bitmask_buf_.ndim(),
-                              xgrammar::GetBitmaskDLType(),
-                              bitmask_shape.data(),
-                              nullptr,
-                              0};
-    bool need_apply = false;
     for (size_t i = 0; i < bsz; ++i) {
         const auto& matcher = matchers_[i];
         if (matcher) {

From 22f4ea3902cf02138a9a856e7bcab2774a7072ad Mon Sep 17 00:00:00 2001
From: Zhongbo Tian
Date: Wed, 10 Dec 2025 20:59:45 +0800
Subject: [PATCH 4/5] fix: add clear_grammar to remove grammar from reused model_request

---
 lmdeploy/turbomind/turbomind.py | 2 ++
 src/turbomind/engine/model_request.cc | 5 +++++
 src/turbomind/engine/model_request.h | 1 +
 src/turbomind/python/bind.cpp | 9 ++++++++-
 4 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 67d57e4e5a..c391399ab4 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -814,6 +814,8 @@ async def async_stream_infer(self,
             while not state or state.status == 0:
                 await sem.acquire()
                 state = shared_state.consume()
+
+        self.model_inst.clear_grammar()
         logger.info(f'[async_stream_infer] session {session_id} done')

     def _get_error_output(self, status):
diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc
index 805b6dac5d..3f0ef92ca3 100644
--- a/src/turbomind/engine/model_request.cc
+++ b/src/turbomind/engine/model_request.cc
@@ -152,4 +152,9 @@ void ModelRequest::setGrammar(const xgrammar::CompiledGrammar& grammar)
     grammar_ = std::make_shared<xgrammar::CompiledGrammar>(grammar);
 }

+void ModelRequest::clearGrammar()
+{
+    grammar_.reset();
+}
+
 }  // namespace turbomind
diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h
index b05a980312..19cd27aca8 100644
--- a/src/turbomind/engine/model_request.h
+++ b/src/turbomind/engine/model_request.h
@@ -41,6 +41,7 @@ class ModelRequest {
     OutputParam Forward(InputParam param, std::function<void()> cb);

     void setGrammar(const xgrammar::CompiledGrammar& grammar);
+    void clearGrammar();

 protected:
     Gateway* const gateway_;
diff --git a/src/turbomind/python/bind.cpp b/src/turbomind/python/bind.cpp
index f4d090fefd..8e2f868ab3 100644
--- a/src/turbomind/python/bind.cpp
+++ b/src/turbomind/python/bind.cpp
@@ -498,7 +498,14 @@ PYBIND11_MODULE(_turbomind, m)
                 model_request->setGrammar(grammar);
             },
             py::call_guard<py::gil_scoped_release>(),
-            "grammar"_a);
+            "grammar"_a)
+        .def(
+            "clear_grammar",
+            [](ModelRequest* model_request) {
+                TM_LOG_INFO("Release grammar for model_request");
+                model_request->clearGrammar();
+            },
+            py::call_guard<py::gil_scoped_release>());

     // transformer model
     using ft::LlamaTritonModel;

From c413dbe92bd7cc6960270903f41fcfb7359612ca Mon Sep 17 00:00:00 2001
From: Zhongbo Tian
Date: Thu, 11 Dec 2025 12:21:14 +0800
Subject: [PATCH 5/5] test: add tests mixing guided and non-guided requests

---
 tests/test_lmdeploy/test_grammar.py | 55 ++++++++++++++++++++++++++++-
 1 file changed, 54 insertions(+), 1 deletion(-)

diff --git a/tests/test_lmdeploy/test_grammar.py b/tests/test_lmdeploy/test_grammar.py
index 9bfe03cec4..35a72d7063 100644
--- a/tests/test_lmdeploy/test_grammar.py
+++ b/tests/test_lmdeploy/test_grammar.py
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import re

@@ -5,7 +6,7 @@
 from jsonschema import validate

 from lmdeploy import pipeline
-from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, TurbomindEngineConfig
+from lmdeploy.messages import GenerationConfig, PytorchEngineConfig, Response, TurbomindEngineConfig

 MODEL_IDS = [
     'Qwen/Qwen3-0.6B',
@@ -95,3 +96,55 @@ def test_guided_matrix(model_id, backend_name, backend_factory, schema_type):
             assert re.fullmatch(schema, response[0].text)
     finally:
         pipe.close()
+
+
+async def collect(*aiters):
+    results = [[] for _ in range(len(aiters))]
+
+    async def drain(idx, aiter):
+        async for item in aiter:
+            results[idx].append(item)
+
+    await asyncio.gather(*(drain(idx, aiter) for idx, aiter in enumerate(aiters)))
+
+    responses = []
+    for r in results:
+        resp = Response(text='', input_token_len=0, generate_token_len=0)
+        responses.append(resp)
+        for out in r:
+            resp.text += out.response
+            resp.input_token_len = out.input_token_len
+            resp.generate_token_len = out.generate_token_len
+            resp.finish_reason = out.finish_reason
+
+    return responses
+
+
+@pytest.mark.parametrize('model_id', MODEL_IDS)
+@pytest.mark.parametrize('backend_name,backend_factory', BACKEND_FACTORIES)
+def test_mix_guided_matrix(model_id, backend_name, backend_factory):
+    pipe = pipeline(
+        model_id,
+        backend_config=backend_factory(),
+        log_level='INFO',
+    )
+
+    schema_type = 'json_schema'
+    response_format = {'type': schema_type}
+    schema = SCHEMA_MAP[schema_type]
+    response_format[schema_type] = dict(name='test', schema=schema)
+
+    gen_config = GenerationConfig(response_format=response_format)
+
+    configs = [None if idx % 3 else gen_config for idx in range(4)]
+    tasks = [
+        pipe.generate(messages='Make a self introduction please.', session_id=session_id, gen_config=gen_config)
+        for session_id, gen_config in enumerate(configs)
+    ]
+
+    responses = asyncio.run(collect(*tasks))
+    for resp, config in zip(responses, configs):
+        if config is None:
+            assert '}' not in resp.text
+        else:
+            validate(instance=json.loads(resp.text), schema=schema)
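
Note for reviewers: the short sketch below is a minimal, self-contained illustration of the failure mode that PATCH 1/5 addresses; the class and helper names (ToyMatcher, decode_step) are invented for the example and are not the real xgrammar or TurboMind API. With tp > 1, every TP rank runs the guided-decoding mask/update layers on the same sampled token, so a single matcher shared by all ranks is advanced tp_size times per step and its grammar state diverges from the actual output; giving each rank its own copy (what r->matchers[tp_rank_] provides) keeps the state consistent.

import copy


class ToyMatcher:
    """Stand-in for a grammar matcher: expects exactly one accept() per generated token."""

    def __init__(self):
        self.position = 0  # number of tokens the grammar state has consumed

    def accept(self, token_id):
        self.position += 1


def decode_step(matchers, token_id):
    # The update layer runs on every TP rank; each rank advances the matcher it was given.
    for rank_matcher in matchers:
        rank_matcher.accept(token_id)


tp_size = 2
shared = ToyMatcher()
before_fix = [shared] * tp_size                              # old behaviour: one matcher shared by all ranks
after_fix = [copy.deepcopy(shared) for _ in range(tp_size)]  # new behaviour: one matcher per rank

decode_step(before_fix, token_id=42)
decode_step(after_fix, token_id=42)

print(before_fix[0].position)  # 2 -> advanced twice for a single token: corrupted grammar state
print(after_fix[0].position)   # 1 -> each rank's copy stays consistent with the generated output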