
Commit f90f5d3

refactor: separate mlu and cuda version Qwen model implementation.
1 parent 4d97206 commit f90f5d3

File tree

16 files changed: +2585, −107 lines


xllm/core/layers/common/indexer.h

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ limitations under the License.
 #include "../mlu/attention.h"
 #elif defined(USE_CUDA)
 #include "../cuda/attention.h"
-#endif #include "framework/kv_cache/kv_cache.h"
+#endif
 #include "framework/model/model_input_params.h"
 #include "framework/parallel_state/parallel_args.h"
 #include "framework/quant_args.h"

xllm/models/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -1,6 +1,5 @@
 include(cc_library)
 
-# Define the library
 cc_library(
   NAME
     models

xllm/models/llm/mlu/deepseek_v2.h renamed to xllm/models/llm/common/deepseek_v2.h

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "core/layers/deepseek_v2_decoder_layer.h"
-#include "models/llm/llm_model_base.h"
+#include "llm_model_base.h"
 
 // DeepSeek v2 compatible with huggingface weights
 // ref to:
File renamed without changes.
File renamed without changes.
xllm/models/llm/common/llm_model_base.h (new file; path inferred from the relative "llm_model_base.h" includes in deepseek_v2.h and qwen2.h)

Lines changed: 240 additions & 0 deletions
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include <string>
#include <typeinfo>
#include <vector>

#include "core/common/global_flags.h"
#include "core/common/interruption_bus.h"
#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/model/model_input_params.h"
#include "core/framework/model_context.h"
#include "core/layers/attention_mask.h"
#include "core/layers/common/layer_utils.h"
#include "core/layers/lm_head.h"
#include "core/layers/rms_norm.h"
#include "models/model_registry.h"
#if defined(USE_CUDA)
#include "core/layers/cuda/attention.h"
#endif
#if defined(USE_MLU)
#include "core/layers/mlu/attention.h"
#endif

namespace xllm {

template <typename DecoderType>
class LlmDecoderLayerImplBase : public torch::nn::Module {
 public:
  LlmDecoderLayerImplBase(const ModelContext& context) {
    // register submodules
    decoder_layer_ = register_module("decoder_layer", DecoderType(context));
  }

  virtual torch::Tensor forward(torch::Tensor& x,
                                torch::Tensor& positions,
                                const layer::AttentionMetadata& attn_metadata,
                                KVCache& kv_cache,
                                const ModelInputParams& input_params) {
    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
  }

  // load the weights from the checkpoint
  virtual void load_state_dict(const StateDict& state_dict) {
    // call each submodule's load_state_dict function
    decoder_layer_->load_state_dict(state_dict);
  }

 private:
  DecoderType decoder_layer_{nullptr};
};

template <typename DecoderLayerType>
class LlmModelImplBase : public torch::nn::Module {
 public:
  // model type: qwen2, qwen3, etc.
  LlmModelImplBase(const std::string& model_type, const ModelArgs& args)
      : model_type_(model_type) {
    InterruptionBus::get_instance().subscribe([this](bool interrupted) {
      this->layer_forward_interrupted_ = interrupted;
    });
    mrope_section_ = args.rope_scaling_mrope_section();
  }

  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
    return embed_tokens_(input_ids);
  }

  // tokens: [num_tokens]
  // positions: [num_tokens] token pos in the sequence
  virtual torch::Tensor forward(torch::Tensor tokens,
                                torch::Tensor positions,
                                std::vector<KVCache>& kv_caches,
                                const ModelInputParams& input_params) {
    if (tokens.numel() == 0) {
      tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device());
      positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device());
    }
    auto inputs_embeds = input_params.input_embedding;
    // test
    torch::Tensor h;
    if (inputs_embeds.defined()) {
      h = inputs_embeds;
    } else {
      h = embed_tokens_(tokens);
    }

    auto modified_input_params = input_params;
    auto position = positions;
    layer::update_dummy_run_input(dp_rank_, position, modified_input_params);
    bool is_prefill = modified_input_params.q_max_seq_len > 1;
    auto attn_metadata =
        layer::AttentionMetadata::build(modified_input_params, is_prefill);

    torch::Tensor h_ret;
    for (size_t i = 0; i < layers_.size(); i++) {
      auto& layer = layers_[i];
      h_ret = layer(
          h, position, attn_metadata, kv_caches[i], modified_input_params);
    }
    return norm_(h_ret);
  }

  // load the weights from the checkpoint
  virtual void load_state_dict(const StateDict& state_dict) {
    embed_tokens_->load_state_dict(
        state_dict.get_dict_with_prefix("embed_tokens."));

    // call each layer's load_state_dict function
    for (int i = 0; i < layers_.size(); i++) {
      layers_[i]->load_state_dict(
          state_dict.get_dict_with_prefix("layers." + std::to_string(i) + "."));
    }
    norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
  }

  virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

  virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
    embed_tokens_ = word_embedding;
  }

 protected:
  int max_seq_len_ = 0;
  torch::Tensor cos_pos_;
  torch::Tensor sin_pos_;
  int device_id = 0;
  layer::AttentionMask attn_mask_;
  int dp_rank_ = 0;

  std::vector<int64_t> mrope_section_;
  // test
  // ParallelEmbedding embed_tokens_{nullptr};
  layer::WordEmbedding embed_tokens_{nullptr};
  layer::RmsNorm norm_{nullptr};

  torch::nn::ModuleList blocks_{nullptr};
  // holds the same layers as blocks_, but concretely typed to avoid casts
  std::vector<DecoderLayerType> layers_;

  bool layer_forward_interrupted_ = false;

 private:
  std::string model_type_;
};

template <typename LlmModelType>
class LlmForCausalLMImplBase : public torch::nn::Module {
 public:
  LlmForCausalLMImplBase(const ModelContext& context) {
    tie_word_embeddings = context.get_model_args().tie_word_embeddings();
    // register submodules
    model_ = register_module("model", LlmModelType(context));

    lm_head_ = register_module("lm_head", layer::LmHead(context));
  }

  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
    return model_->get_input_embeddings(input_ids);
  }

  // tokens: [num_tokens]
  // positions: [num_tokens] token pos in the sequence
  // returns: [num_tokens, hidden_size]
  virtual torch::Tensor forward(const torch::Tensor& tokens,
                                const torch::Tensor& positions,
                                std::vector<KVCache>& kv_caches,
                                const ModelInputParams& input_params) {
    return model_(tokens, positions, kv_caches, input_params);
  }

  // hidden_states: [num_tokens, hidden_size]
  // seleted_idxes: [num_tokens]
  // returns: [num_tokens, vocab_size]
  virtual torch::Tensor logits(const torch::Tensor& hidden_states,
                               const torch::Tensor& seleted_idxes) {
    // select tokens if provided
    auto h = hidden_states;
    if (seleted_idxes.defined()) {
      h = h.index_select(/*dim=*/0, seleted_idxes);
    }
    return lm_head_(h);
  }

  void load_model(std::unique_ptr<ModelLoader> loader,
                  std::string prefix = "model." /*llm model weight prefix*/) {
    for (const auto& state_dict : loader->get_state_dicts()) {
      model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
      if (tie_word_embeddings) {
        lm_head_->load_state_dict(
            state_dict->get_dict_with_prefix(prefix + "embed_tokens."));
      } else {
        lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
      }
    }
  }

  virtual void prepare_expert_weight(int32_t layer_id,
                                     const std::vector<int32_t>& expert_ids) {
    return;
  }
  virtual void update_expert_weight(int32_t layer_id) { return; }

  virtual layer::LmHead get_lm_head() { return lm_head_; }

  virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; }

  virtual layer::WordEmbedding get_word_embedding() {
    return model_->get_word_embedding();
  }

  virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
    model_->set_word_embedding(word_embedding);
  }

 protected:
  // parameter members, must be registered
  LlmModelType model_{nullptr};
  int device_id = 0;
  bool tie_word_embeddings{false};
  layer::LmHead lm_head_{nullptr};
};

}  // namespace xllm
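Reading the two loaders together, the expected checkpoint key layout falls out of the code above. This mapping is derived solely from the prefixes in load_model and the load_state_dict overrides, with the default prefix "model.":

    model.embed_tokens.*  ->  embed_tokens_  (also loaded into lm_head_ when tie_word_embeddings is true)
    model.layers.<i>.*    ->  layers_[i]
    model.norm.*          ->  norm_
    lm_head.*             ->  lm_head_       (when tie_word_embeddings is false)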

xllm/models/llm/common/qwen2.h

Lines changed: 110 additions & 0 deletions
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "core/layers/qwen2_decoder_layer.h"
#include "llm_model_base.h"

// QWen2 model compatible with huggingface weights
// ref to:
// https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py
namespace xllm {

class QWen2DecoderLayerImpl
    : public LlmDecoderLayerImplBase<layer::Qwen2DecoderLayer> {
 public:
  QWen2DecoderLayerImpl(const ModelContext& context)
      : LlmDecoderLayerImplBase<layer::Qwen2DecoderLayer>(context) {}
};
TORCH_MODULE(QWen2DecoderLayer);

class QWen2ModelImpl : public LlmModelImplBase<QWen2DecoderLayer> {
 public:
  QWen2ModelImpl(const ModelContext& context)
      : LlmModelImplBase<QWen2DecoderLayer>("qwen2", context.get_model_args()) {
    // register submodules
    auto model_args = context.get_model_args();
    auto options = context.get_tensor_options();
    auto parallel_args = context.get_parallel_args();
    auto dp_local_tp_size =
        parallel_args.world_size() / parallel_args.dp_size();
    dp_rank_ = parallel_args.rank() / dp_local_tp_size;

    blocks_ = register_module("layers", torch::nn::ModuleList());
    layers_.reserve(model_args.n_layers());
    norm_ = register_module("norm", layer::RmsNorm(context));
    embed_tokens_ =
        register_module("embed_tokens", layer::WordEmbedding(context));
    int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
    attn_mask_ = layer::AttentionMask(options.device(),
                                      options.dtype().toScalarType(),
                                      /*mask_value=*/mask_value);

    for (int32_t i = 0; i < model_args.n_layers(); i++) {
      auto block = QWen2DecoderLayer(context);
      layers_.push_back(block);
      blocks_->push_back(block);
    }
  }
};
TORCH_MODULE(QWen2Model);

class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase<QWen2Model> {
 public:
  QWen2ForCausalLMImpl(const ModelContext& context)
      : LlmForCausalLMImplBase<QWen2Model>(context) {}
};
TORCH_MODULE(QWen2ForCausalLM);

// register the causal model
REGISTER_CAUSAL_MODEL(qwen2, QWen2ForCausalLM);

// register the model args
// example config:
// https://huggingface.co/Qwen/Qwen2-7B-Instruct/blob/main/config.json
REGISTER_MODEL_ARGS(qwen2, [&] {
  LOAD_ARG_OR(model_type, "model_type", "qwen2");
  LOAD_ARG_OR(dtype, "torch_dtype", "");
  LOAD_ARG_OR(vocab_size, "vocab_size", 152064);
  LOAD_ARG_OR(hidden_size, "hidden_size", 3584);
  LOAD_ARG_OR(n_layers, "num_hidden_layers", 28);
  LOAD_ARG_OR(n_heads, "num_attention_heads", 28);
  LOAD_ARG(n_kv_heads, "num_key_value_heads");
  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
  LOAD_ARG_OR(attention_bias, "attention_bias", true);
  // LOAD_ARG_OR(no_bias, "no_bias", true);
  LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944);
  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768);
  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
  LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643);
  LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);

  // For Qwen2/2.5 models < 7B, tie_word_embeddings = true
  LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);

  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
  LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
  LOAD_ARG_OR(max_window_layers, "max_window_layers", 28);

  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
    return args->hidden_size() / args->n_heads();
  });

  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({args->eos_token_id()}));
});

}  // namespace xllm
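A minimal sketch of how these pieces compose at runtime. Only the call sequence (construction, load_model, forward, logits) comes from the classes above; the setup of ctx, loader, kv_caches, and the input tensors is elided and assumed to be provided by the surrounding engine, so treat those names as placeholders:

    // Assumed available from the runtime: ModelContext ctx,
    // std::unique_ptr<ModelLoader> loader, std::vector<KVCache> kv_caches,
    // torch::Tensor tokens/positions/seleted_idxes, ModelInputParams params.
    QWen2ForCausalLM model(ctx);            // registers model_ and lm_head_
    model->load_model(std::move(loader));   // default weight prefix "model."

    // [num_tokens] -> [num_tokens, hidden_size]
    torch::Tensor hidden = model->forward(tokens, positions, kv_caches, params);
    // gather the rows to score, then project: -> [num_selected, vocab_size]
    torch::Tensor logits = model->logits(hidden, seleted_idxes);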
