Skip to content

Commit 7b8f60c

Browse files
committed
refactor: separate mlu and cuda version Qwen model implementation.
1 parent 4d97206 commit 7b8f60c

32 files changed

+3575
-1209
lines changed

xllm/core/layers/common/indexer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ limitations under the License.
2525
#include "../mlu/attention.h"
2626
#elif defined(USE_CUDA)
2727
#include "../cuda/attention.h"
28-
#endif #include "framework/kv_cache/kv_cache.h"
28+
#endif
2929
#include "framework/model/model_input_params.h"
3030
#include "framework/parallel_state/parallel_args.h"
3131
#include "framework/quant_args.h"

xllm/models/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
include(cc_library)
22

3-
# Define the library
43
cc_library(
54
NAME
65
models

xllm/models/llm/deepseek_v2.h

Lines changed: 54 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
15-
1615
#pragma once
1716

18-
#include <gflags/gflags.h>
1917
#include <torch/torch.h>
2018

21-
#include <boost/algorithm/string.hpp>
2219
#include <string>
2320
#include <vector>
2421

25-
#include "core/common/global_flags.h"
26-
#include "core/framework/kv_cache/kv_cache.h"
27-
#include "core/framework/model/model_input_params.h"
28-
#include "core/framework/model/npu_dp_ep_padding.h"
29-
#include "core/framework/model_context.h"
30-
#include "core/layers/attention_mask.h"
3122
#include "core/layers/deepseek_v2_decoder_layer.h"
32-
#include "core/layers/lm_head.h"
33-
#include "core/layers/pos_embedding.h"
34-
#include "core/layers/rms_norm.h"
35-
#include "core/layers/rotary_embedding.h"
36-
#include "core/layers/word_embedding.h"
37-
#include "models/model_registry.h"
23+
#include "llm_model_base.h"
24+
3825
// DeepSeek v2 compatible with huggingface weights
3926
// ref to:
4027
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py
@@ -46,47 +33,29 @@ using ISlice = torch::indexing::Slice;
4633

4734
class DeepseekV2DecoderLayerImpl : public torch::nn::Module {
4835
public:
49-
DeepseekV2DecoderLayerImpl(const ModelContext& context,
50-
const int32_t i,
51-
const float sm_scale) {
36+
DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) {
5237
// register submodules
53-
decoder_layer_ = register_module(
54-
"decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale));
38+
decoder_layer_ = register_module("decoder_layer",
39+
layer::DeepseekV2DecoderLayer(context, i));
5540
}
5641

5742
torch::Tensor forward(torch::Tensor& x,
58-
torch::Tensor& cos_pos,
59-
torch::Tensor& sin_pos,
60-
torch::Tensor& attn_mask,
43+
torch::Tensor& positions,
44+
const layer::AttentionMetadata& attn_metadata,
6145
KVCache& kv_cache,
62-
const ModelInputParams& input_params,
63-
aclrtEvent* event,
64-
std::atomic<bool>* event_flag) {
65-
return decoder_layer_(x,
66-
cos_pos,
67-
sin_pos,
68-
attn_mask,
69-
kv_cache,
70-
input_params,
71-
event,
72-
event_flag);
46+
const ModelInputParams& input_params) {
47+
return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
7348
}
7449

7550
void load_state_dict(const StateDict& state_dict) {
7651
decoder_layer_->load_state_dict(state_dict);
7752
}
7853

79-
void verify_loaded_weights(const std::string& prefix) const {
80-
decoder_layer_->verify_loaded_weights(prefix);
81-
}
82-
83-
void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }
84-
85-
void prepare_expert_weight(const std::vector<int32_t>& expert_list) {
86-
decoder_layer_->prepare_expert_weight(expert_list);
54+
virtual void prepare_expert_weight(int32_t layer_id,
55+
const std::vector<int32_t>& expert_ids) {
56+
return;
8757
}
88-
89-
void update_expert_weight() { decoder_layer_->update_expert_weight(); }
58+
virtual void update_expert_weight(int32_t layer_id) { return; }
9059

9160
private:
9261
layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
@@ -95,114 +64,71 @@ TORCH_MODULE(DeepseekV2DecoderLayer);
9564

9665
class DeepseekV2ModelImpl : public torch::nn::Module {
9766
public:
98-
DeepseekV2ModelImpl(const ModelContext& context)
99-
: device_(context.get_tensor_options().device()) {
67+
DeepseekV2ModelImpl(const ModelContext& context) {
10068
auto options = context.get_tensor_options();
10169
auto model_args = context.get_model_args();
10270
auto parallel_args = context.get_parallel_args();
10371

10472
blocks_ = register_module("layers", torch::nn::ModuleList());
10573
layers_.reserve(model_args.n_layers());
74+
10675
// register submodules
107-
device_ = options.device();
108-
dtype_ = options.dtype().toScalarType();
10976
num_speculative_tokens_ = model_args.num_speculative_tokens();
11077

111-
// rotary positional embedding
112-
auto inv_freq = rotary::apply_deepseek_yarn_rope_scaling(
113-
model_args.rope_scaling_factor(),
114-
model_args.rope_extrapolation_factor(),
115-
model_args.rope_scaling_beta_fast(),
116-
model_args.rope_scaling_beta_slow(),
117-
model_args.rotary_dim(),
118-
model_args.rope_theta(),
119-
model_args.rope_scaling_original_max_position_embeddings());
120-
embed_tokens_ =
121-
register_module("embed_tokens", layer::WordEmbedding(context));
122-
float sm_scale = 1.0f;
123-
pos_emb_ = create_rotary_embedding(model_args,
124-
model_args.rotary_dim(),
125-
inv_freq,
126-
/*interleaved=*/false,
127-
sm_scale,
128-
options);
129-
atb_pos_emb_ = layer::PosEmbedding(context);
130-
131-
max_seq_len_ = model_args.max_position_embeddings();
132-
int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
133-
attn_mask_ = layer::AttentionMask(options.device(),
134-
options.dtype().toScalarType(),
135-
/*mask_value=*/mask_value);
78+
// MTP is not supported for now
79+
if (num_speculative_tokens_ > 0) {
80+
LOG(FATAL) << "DeepSeek MTP on MLU is not support for now";
81+
}
13682

83+
embed_tokens_ =
84+
register_module("embed_tokens",
85+
layer::WordEmbedding(model_args.vocab_size(),
86+
model_args.hidden_size(),
87+
context.get_parallel_args(),
88+
options));
89+
norm_ = register_module(
90+
"norm",
91+
layer::RmsNorm(
92+
model_args.hidden_size(), model_args.rms_norm_eps(), options));
93+
94+
// create decoder layers
13795
for (int32_t i = 0; i < model_args.n_layers(); ++i) {
138-
auto block = DeepseekV2DecoderLayer(context, i, sm_scale);
96+
auto block = DeepseekV2DecoderLayer(context, i);
13997
layers_.push_back(block);
14098
blocks_->push_back(block);
14199
}
142100

143-
norm_ = register_module("norm", layer::RmsNorm(context));
144-
// dp_size_=4;
145101
dp_size_ = parallel_args.dp_size();
146102
std::vector<int64_t> indices;
147103
dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
148104
dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
149105
rank_ = parallel_args.rank();
150-
mapping_data_ = parallel_args.mapping_data();
151-
num_experts_per_tok_ = model_args.num_experts_per_tok();
152106
for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
153107
indices.push_back(i);
154108
}
155109
}
156110

157-
torch::Tensor forward(torch::Tensor tokens,
158-
torch::Tensor positions,
159-
std::vector<KVCache>& kv_caches,
160-
const ModelInputParams& input_params) {
161-
if (dp_size_ > 1) {
162-
if (tokens.sizes() == 0) {
163-
tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
164-
positions = torch::tensor({0}).to(torch::kInt32).to(device_);
165-
}
166-
}
167-
168-
auto h = embed_tokens_(tokens, 0);
169-
auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
170-
auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
171-
auto cos_pos = cos_sin_chunks[0].contiguous();
172-
auto sin_pos = cos_sin_chunks[1].contiguous();
173-
174-
torch::Tensor attn_mask;
175-
if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
176-
attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
177-
} else {
178-
attn_mask = attn_mask_.gen_free_mask(
179-
num_speculative_tokens_ + 1, dtype_, device_);
180-
}
181-
111+
torch::Tensor forward_native(torch::Tensor tokens,
112+
torch::Tensor positions,
113+
std::vector<KVCache>& kv_caches,
114+
const ModelInputParams& input_params) {
115+
bool is_prefill = input_params.q_max_seq_len > 1;
116+
auto attn_metadata =
117+
layer::AttentionMetadata::build(input_params, is_prefill);
118+
torch::Tensor h = embed_tokens_(tokens);
182119
for (size_t i = 0; i < layers_.size(); i++) {
183-
aclrtEvent* event = nullptr;
184-
std::atomic<bool>* event_flag = nullptr;
185-
if (input_params.layer_synchronizer != nullptr) {
186-
event = input_params.layer_synchronizer->get_event(i);
187-
event_flag = input_params.layer_synchronizer->get_event_flag(i);
188-
}
189-
if (input_params.layer_wise_load_synchronizer != nullptr) {
190-
if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
191-
return torch::Tensor();
192-
}
193-
}
194-
195120
auto& layer = layers_[i];
196-
layer(h,
197-
cos_pos,
198-
sin_pos,
199-
attn_mask,
200-
kv_caches[i],
201-
input_params,
202-
event,
203-
event_flag);
121+
h = layer(h, positions, attn_metadata, kv_caches[i], input_params);
204122
}
205-
return norm_(h, 0);
123+
return norm_(h);
124+
}
125+
126+
// Provide batched signature to satisfy callers that pass vectors
127+
torch::Tensor forward(const torch::Tensor& tokens,
128+
const torch::Tensor& positions,
129+
std::vector<KVCache>& kv_caches,
130+
const ModelInputParams& input_params) {
131+
return forward_native(tokens, positions, kv_caches, input_params);
206132
}
207133

208134
// load the weight from the checkpoint
@@ -217,32 +143,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
217143
norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
218144
}
219145

220-
void verify_loaded_weights(const std::string& prefix) const {
221-
embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
222-
for (int i = 0; i < layers_.size(); i++) {
223-
layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
224-
".");
225-
}
226-
norm_->verify_loaded_weights(prefix + "norm.");
227-
}
228-
229-
void merge_loaded_weights() {
230-
embed_tokens_->merge_loaded_weights();
231-
for (int i = 0; i < layers_.size(); i++) {
232-
layers_[i]->merge_loaded_weights();
233-
}
234-
norm_->merge_loaded_weights();
235-
}
236-
237-
void prepare_expert_weight(int32_t layer_id,
238-
const std::vector<int32_t>& expert_ids) {
239-
layers_[layer_id]->prepare_expert_weight(expert_ids);
240-
}
241-
242-
void update_expert_weight(int32_t layer_id) {
243-
layers_[layer_id]->update_expert_weight();
244-
}
245-
246146
layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
247147

248148
void set_word_embedding(layer::WordEmbedding& word_embedding) {
@@ -252,90 +152,21 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
252152
private:
253153
torch::nn::ModuleList blocks_{nullptr};
254154
std::vector<DeepseekV2DecoderLayer> layers_;
255-
int32_t max_seq_len_ = 0;
256155
int32_t dp_rank_;
257156
int32_t rank_;
258157
int32_t dp_size_;
259158
int32_t dp_local_tp_size_;
260-
nlohmann::json mapping_data_;
261-
int32_t num_experts_per_tok_;
262159
int32_t num_speculative_tokens_ = 0;
263-
at::Device device_;
264-
torch::Dtype dtype_;
265160
layer::WordEmbedding embed_tokens_{nullptr};
266-
std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
267-
layer::PosEmbedding atb_pos_emb_{nullptr};
268-
layer::AttentionMask attn_mask_;
269161
layer::RmsNorm norm_{nullptr};
270162
};
271163
TORCH_MODULE(DeepseekV2Model);
272164

273-
class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
165+
class DeepseekV2ForCausalLMImpl
166+
: public LlmForCausalLMImplBase<DeepseekV2Model> {
274167
public:
275-
DeepseekV2ForCausalLMImpl(const ModelContext& context) {
276-
model_ = register_module("model", DeepseekV2Model(context));
277-
lm_head_ = register_module("lm_head", layer::LmHead(context));
278-
first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
279-
}
280-
281-
// tokens: [num_tokens]
282-
// positions: [num_tokens] token pos in the sequence
283-
// returns: [num_tokens, hidden_size]
284-
torch::Tensor forward(const torch::Tensor& tokens,
285-
const torch::Tensor& positions,
286-
std::vector<KVCache>& kv_caches,
287-
const ModelInputParams& input_params) {
288-
return model_(tokens, positions, kv_caches, input_params);
289-
}
290-
291-
// hidden_states: [num_tokens, hidden_size]
292-
// seleted_idxes: [num_tokens]
293-
// returns: [num_tokens, vocab_size]
294-
torch::Tensor logits(const torch::Tensor& hidden_states,
295-
const torch::Tensor& seleted_idxes) {
296-
return lm_head_(hidden_states, seleted_idxes, 0);
297-
}
298-
299-
void load_model(std::unique_ptr<ModelLoader> loader) {
300-
for (const auto& state_dict : loader->get_state_dicts()) {
301-
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
302-
lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
303-
}
304-
305-
// verify
306-
model_->verify_loaded_weights("model.");
307-
lm_head_->verify_loaded_weights("lm_head.");
308-
309-
model_->merge_loaded_weights();
310-
lm_head_->merge_loaded_weights();
311-
}
312-
313-
void prepare_expert_weight(int32_t layer_id,
314-
const std::vector<int32_t>& expert_ids) {
315-
model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
316-
expert_ids);
317-
}
318-
319-
void update_expert_weight(int32_t layer_id) {
320-
model_->update_expert_weight(layer_id + first_k_dense_replace_);
321-
}
322-
323-
layer::LmHead get_lm_head() { return lm_head_; }
324-
325-
void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
326-
327-
layer::WordEmbedding get_word_embedding() {
328-
return model_->get_word_embedding();
329-
}
330-
331-
void set_word_embedding(layer::WordEmbedding& word_embedding) {
332-
model_->set_word_embedding(word_embedding);
333-
}
334-
335-
private:
336-
DeepseekV2Model model_{nullptr};
337-
layer::LmHead lm_head_{nullptr};
338-
int32_t first_k_dense_replace_;
168+
DeepseekV2ForCausalLMImpl(const ModelContext& context)
169+
: LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
339170
};
340171
TORCH_MODULE(DeepseekV2ForCausalLM);
341172

0 commit comments

Comments
 (0)