@@ -12,29 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #pragma once

-#include <gflags/gflags.h>
 #include <torch/torch.h>

-#include <boost/algorithm/string.hpp>
 #include <string>
 #include <vector>

-#include "core/common/global_flags.h"
-#include "core/framework/kv_cache/kv_cache.h"
-#include "core/framework/model/model_input_params.h"
-#include "core/framework/model/npu_dp_ep_padding.h"
-#include "core/framework/model_context.h"
-#include "core/layers/attention_mask.h"
 #include "core/layers/deepseek_v2_decoder_layer.h"
-#include "core/layers/lm_head.h"
-#include "core/layers/pos_embedding.h"
-#include "core/layers/rms_norm.h"
-#include "core/layers/rotary_embedding.h"
-#include "core/layers/word_embedding.h"
-#include "models/model_registry.h"
+#include "llm_model_base.h"
+
 // DeepSeek v2 compatible with huggingface weights
 // ref to:
 // https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py
@@ -46,47 +33,29 @@ using ISlice = torch::indexing::Slice;

 class DeepseekV2DecoderLayerImpl : public torch::nn::Module {
  public:
-  DeepseekV2DecoderLayerImpl(const ModelContext& context,
-                             const int32_t i,
-                             const float sm_scale) {
+  DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) {
     // register submodules
-    decoder_layer_ = register_module(
-        "decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale));
+    decoder_layer_ = register_module("decoder_layer",
+                                     layer::DeepseekV2DecoderLayer(context, i));
   }

   torch::Tensor forward(torch::Tensor& x,
-                        torch::Tensor& cos_pos,
-                        torch::Tensor& sin_pos,
-                        torch::Tensor& attn_mask,
+                        torch::Tensor& positions,
+                        const layer::AttentionMetadata& attn_metadata,
                         KVCache& kv_cache,
-                        const ModelInputParams& input_params,
-                        aclrtEvent* event,
-                        std::atomic<bool>* event_flag) {
-    return decoder_layer_(x,
-                          cos_pos,
-                          sin_pos,
-                          attn_mask,
-                          kv_cache,
-                          input_params,
-                          event,
-                          event_flag);
+                        const ModelInputParams& input_params) {
+    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
   }

   void load_state_dict(const StateDict& state_dict) {
     decoder_layer_->load_state_dict(state_dict);
   }

-  void verify_loaded_weights(const std::string& prefix) const {
-    decoder_layer_->verify_loaded_weights(prefix);
-  }
-
-  void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }
-
-  void prepare_expert_weight(const std::vector<int32_t>& expert_list) {
-    decoder_layer_->prepare_expert_weight(expert_list);
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
   }
-
-  void update_expert_weight() { decoder_layer_->update_expert_weight(); }
+  virtual void update_expert_weight(int32_t layer_id) { return; }

  private:
   layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
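The layer contract changes here: callers no longer precompute per-layer cos/sin tables and a dense attention mask; they pass the raw positions tensor plus a shared layer::AttentionMetadata. A minimal call-site sketch under the new signature, reusing the project's existing headers; the helper name run_single_layer and where its arguments come from are hypothetical, not part of this commit:

#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/model/model_input_params.h"
#include "core/framework/model_context.h"
#include "core/layers/deepseek_v2_decoder_layer.h"

// Hypothetical call site: the metadata is built once per step and shared
// by every decoder layer, replacing per-layer cos/sin tensors and masks.
torch::Tensor run_single_layer(const ModelContext& context,
                               torch::Tensor& x,
                               torch::Tensor& positions,
                               KVCache& kv_cache,
                               const ModelInputParams& input_params) {
  DeepseekV2DecoderLayer layer(context, /*i=*/0);
  auto attn_metadata = layer::AttentionMetadata::build(
      input_params, /*is_prefill=*/input_params.q_max_seq_len > 1);
  return layer(x, positions, attn_metadata, kv_cache, input_params);
}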
@@ -95,114 +64,71 @@ TORCH_MODULE(DeepseekV2DecoderLayer);

 class DeepseekV2ModelImpl : public torch::nn::Module {
  public:
-  DeepseekV2ModelImpl(const ModelContext& context)
-      : device_(context.get_tensor_options().device()) {
+  DeepseekV2ModelImpl(const ModelContext& context) {
     auto options = context.get_tensor_options();
     auto model_args = context.get_model_args();
     auto parallel_args = context.get_parallel_args();

     blocks_ = register_module("layers", torch::nn::ModuleList());
     layers_.reserve(model_args.n_layers());
+
     // register submodules
-    device_ = options.device();
-    dtype_ = options.dtype().toScalarType();
     num_speculative_tokens_ = model_args.num_speculative_tokens();

-    // rotary positional embedding
-    auto inv_freq = rotary::apply_deepseek_yarn_rope_scaling(
-        model_args.rope_scaling_factor(),
-        model_args.rope_extrapolation_factor(),
-        model_args.rope_scaling_beta_fast(),
-        model_args.rope_scaling_beta_slow(),
-        model_args.rotary_dim(),
-        model_args.rope_theta(),
-        model_args.rope_scaling_original_max_position_embeddings());
-    embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
-    float sm_scale = 1.0f;
-    pos_emb_ = create_rotary_embedding(model_args,
-                                       model_args.rotary_dim(),
-                                       inv_freq,
-                                       /*interleaved=*/false,
-                                       sm_scale,
-                                       options);
-    atb_pos_emb_ = layer::PosEmbedding(context);
-
-    max_seq_len_ = model_args.max_position_embeddings();
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
-    attn_mask_ = layer::AttentionMask(options.device(),
-                                      options.dtype().toScalarType(),
-                                      /*mask_value=*/mask_value);
+    // MTP is not supported for now
+    if (num_speculative_tokens_ > 0) {
+      LOG(FATAL) << "DeepSeek MTP on MLU is not supported for now";
+    }

+    embed_tokens_ =
+        register_module("embed_tokens",
+                        layer::WordEmbedding(model_args.vocab_size(),
+                                             model_args.hidden_size(),
+                                             context.get_parallel_args(),
+                                             options));
+    norm_ = register_module(
+        "norm",
+        layer::RmsNorm(
+            model_args.hidden_size(), model_args.rms_norm_eps(), options));
+
+    // create decoder layers
     for (int32_t i = 0; i < model_args.n_layers(); ++i) {
-      auto block = DeepseekV2DecoderLayer(context, i, sm_scale);
+      auto block = DeepseekV2DecoderLayer(context, i);
       layers_.push_back(block);
       blocks_->push_back(block);
     }

-    norm_ = register_module("norm", layer::RmsNorm(context));
-    // dp_size_=4;
     dp_size_ = parallel_args.dp_size();
     std::vector<int64_t> indices;
     dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
     dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
     rank_ = parallel_args.rank();
-    mapping_data_ = parallel_args.mapping_data();
-    num_experts_per_tok_ = model_args.num_experts_per_tok();
     for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
       indices.push_back(i);
     }
   }

-  torch::Tensor forward(torch::Tensor tokens,
-                        torch::Tensor positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    if (dp_size_ > 1) {
-      if (tokens.sizes() == 0) {
-        tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
-        positions = torch::tensor({0}).to(torch::kInt32).to(device_);
-      }
-    }
-
-    auto h = embed_tokens_(tokens, 0);
-    auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
-    auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
-    auto cos_pos = cos_sin_chunks[0].contiguous();
-    auto sin_pos = cos_sin_chunks[1].contiguous();
-
-    torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_.gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
-    }
-
+  torch::Tensor forward_native(torch::Tensor tokens,
+                               torch::Tensor positions,
+                               std::vector<KVCache>& kv_caches,
+                               const ModelInputParams& input_params) {
+    bool is_prefill = input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(input_params, is_prefill);
+    torch::Tensor h = embed_tokens_(tokens);
     for (size_t i = 0; i < layers_.size(); i++) {
-      aclrtEvent* event = nullptr;
-      std::atomic<bool>* event_flag = nullptr;
-      if (input_params.layer_synchronizer != nullptr) {
-        event = input_params.layer_synchronizer->get_event(i);
-        event_flag = input_params.layer_synchronizer->get_event_flag(i);
-      }
-      if (input_params.layer_wise_load_synchronizer != nullptr) {
-        if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
-          return torch::Tensor();
-        }
-      }
-
       auto& layer = layers_[i];
-      layer(h,
-            cos_pos,
-            sin_pos,
-            attn_mask,
-            kv_caches[i],
-            input_params,
-            event,
-            event_flag);
+      h = layer(h, positions, attn_metadata, kv_caches[i], input_params);
     }
-    return norm_(h, 0);
+    return norm_(h);
+  }
+
+  // Provide batched signature to satisfy callers that pass vectors
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    return forward_native(tokens, positions, kv_caches, input_params);
   }

   // load the weight from the checkpoint
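The rewritten forward path reduces to: embed tokens, build one AttentionMetadata, thread the hidden states through every decoder layer, then apply the final RMSNorm; forward() is a thin delegate to forward_native(). Note the prefill test is simply q_max_seq_len > 1, so a batch where each sequence contributes exactly one new token takes the decode path. A hedged driver sketch under that contract; the name model_step and the assumption that a batch scheduler prepares these arguments are illustrative only:

// Hypothetical single-step driver; tokens, positions, kv_caches and
// input_params are assumed to come from the runtime's batch scheduler.
torch::Tensor model_step(DeepseekV2Model& model,
                         torch::Tensor tokens,
                         torch::Tensor positions,
                         std::vector<KVCache>& kv_caches,
                         const ModelInputParams& input_params) {
  // forward() delegates to forward_native(); a pure decode batch
  // (q_max_seq_len == 1) selects the non-prefill metadata build.
  return model(tokens, positions, kv_caches, input_params);
}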
@@ -217,32 +143,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
     norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
   }

-  void verify_loaded_weights(const std::string& prefix) const {
-    embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
-                                        ".");
-    }
-    norm_->verify_loaded_weights(prefix + "norm.");
-  }
-
-  void merge_loaded_weights() {
-    embed_tokens_->merge_loaded_weights();
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->merge_loaded_weights();
-    }
-    norm_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    layers_[layer_id]->prepare_expert_weight(expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    layers_[layer_id]->update_expert_weight();
-  }
-
   layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

   void set_word_embedding(layer::WordEmbedding& word_embedding) {
@@ -252,90 +152,21 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
  private:
   torch::nn::ModuleList blocks_{nullptr};
   std::vector<DeepseekV2DecoderLayer> layers_;
-  int32_t max_seq_len_ = 0;
   int32_t dp_rank_;
   int32_t rank_;
   int32_t dp_size_;
   int32_t dp_local_tp_size_;
-  nlohmann::json mapping_data_;
-  int32_t num_experts_per_tok_;
   int32_t num_speculative_tokens_ = 0;
-  at::Device device_;
-  torch::Dtype dtype_;
   layer::WordEmbedding embed_tokens_{nullptr};
-  std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
-  layer::PosEmbedding atb_pos_emb_{nullptr};
-  layer::AttentionMask attn_mask_;
   layer::RmsNorm norm_{nullptr};
 };
 TORCH_MODULE(DeepseekV2Model);

-class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
+class DeepseekV2ForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekV2Model> {
  public:
-  DeepseekV2ForCausalLMImpl(const ModelContext& context) {
-    model_ = register_module("model", DeepseekV2Model(context));
-    lm_head_ = register_module("lm_head", layer::LmHead(context));
-    first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
-  }
-
-  // tokens: [num_tokens]
-  // positions: [num_tokens] token pos in the sequence
-  // returns: [num_tokens, hidden_size]
-  torch::Tensor forward(const torch::Tensor& tokens,
-                        const torch::Tensor& positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    return model_(tokens, positions, kv_caches, input_params);
-  }
-
-  // hidden_states: [num_tokens, hidden_size]
-  // seleted_idxes: [num_tokens]
-  // returns: [num_tokens, vocab_size]
-  torch::Tensor logits(const torch::Tensor& hidden_states,
-                       const torch::Tensor& seleted_idxes) {
-    return lm_head_(hidden_states, seleted_idxes, 0);
-  }
-
-  void load_model(std::unique_ptr<ModelLoader> loader) {
-    for (const auto& state_dict : loader->get_state_dicts()) {
-      model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
-      lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
-    }
-
-    // verify
-    model_->verify_loaded_weights("model.");
-    lm_head_->verify_loaded_weights("lm_head.");
-
-    model_->merge_loaded_weights();
-    lm_head_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
-                                  expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    model_->update_expert_weight(layer_id + first_k_dense_replace_);
-  }
-
-  layer::LmHead get_lm_head() { return lm_head_; }
-
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
-
-  layer::WordEmbedding get_word_embedding() {
-    return model_->get_word_embedding();
-  }
-
-  void set_word_embedding(layer::WordEmbedding& word_embedding) {
-    model_->set_word_embedding(word_embedding);
-  }
-
- private:
-  DeepseekV2Model model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
-  int32_t first_k_dense_replace_;
+  DeepseekV2ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
 };
 TORCH_MODULE(DeepseekV2ForCausalLM);

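The causal-LM wrapper now inherits its registered model, lm_head, forward/logits, and weight-loading plumbing from LlmForCausalLMImplBase in llm_model_base.h, which this diff does not show. A hedged outline of the shape such a base typically takes, so the two-line subclass reads on its own; every member below is an assumption about that header, not its actual API:

// Assumed outline only; see llm_model_base.h for the real definition.
template <typename Model>
class LlmForCausalLMImplBase : public torch::nn::Module {
 public:
  explicit LlmForCausalLMImplBase(const ModelContext& context) {
    model_ = register_module("model", Model(context));
    lm_head_ = register_module("lm_head", layer::LmHead(context));
  }

  // tokens/positions: [num_tokens]; returns [num_tokens, hidden_size]
  torch::Tensor forward(const torch::Tensor& tokens,
                        const torch::Tensor& positions,
                        std::vector<KVCache>& kv_caches,
                        const ModelInputParams& input_params) {
    return model_(tokens, positions, kv_caches, input_params);
  }

 protected:
  Model model_{nullptr};
  layer::LmHead lm_head_{nullptr};
};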