diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h index c45788c00..992fc9749 100644 --- a/xllm/core/layers/common/indexer.h +++ b/xllm/core/layers/common/indexer.h @@ -25,7 +25,7 @@ limitations under the License. #include "../mlu/attention.h" #elif defined(USE_CUDA) #include "../cuda/attention.h" -#endif #include "framework/kv_cache/kv_cache.h" +#endif #include "framework/model/model_input_params.h" #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h" diff --git a/xllm/models/CMakeLists.txt b/xllm/models/CMakeLists.txt index ed638c539..be03d2561 100644 --- a/xllm/models/CMakeLists.txt +++ b/xllm/models/CMakeLists.txt @@ -1,6 +1,5 @@ include(cc_library) -# Define the library cc_library( NAME models diff --git a/xllm/models/llm/deepseek_v2.h b/xllm/models/llm/deepseek_v2.h index 5d85f8af8..2e552e899 100644 --- a/xllm/models/llm/deepseek_v2.h +++ b/xllm/models/llm/deepseek_v2.h @@ -12,29 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #pragma once -#include #include -#include #include #include -#include "core/common/global_flags.h" -#include "core/framework/kv_cache/kv_cache.h" -#include "core/framework/model/model_input_params.h" -#include "core/framework/model/npu_dp_ep_padding.h" -#include "core/framework/model_context.h" -#include "core/layers/attention_mask.h" #include "core/layers/deepseek_v2_decoder_layer.h" -#include "core/layers/lm_head.h" -#include "core/layers/pos_embedding.h" -#include "core/layers/rms_norm.h" -#include "core/layers/rotary_embedding.h" -#include "core/layers/word_embedding.h" -#include "models/model_registry.h" +#include "llm_model_base.h" + // DeepSeek v2 compatible with huggingface weights // ref to: // https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py @@ -46,47 +33,29 @@ using ISlice = torch::indexing::Slice; class DeepseekV2DecoderLayerImpl : public torch::nn::Module { public: - DeepseekV2DecoderLayerImpl(const ModelContext& context, - const int32_t i, - const float sm_scale) { + DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) { // register submodules - decoder_layer_ = register_module( - "decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale)); + decoder_layer_ = register_module("decoder_layer", + layer::DeepseekV2DecoderLayer(context, i)); } torch::Tensor forward(torch::Tensor& x, - torch::Tensor& cos_pos, - torch::Tensor& sin_pos, - torch::Tensor& attn_mask, + torch::Tensor& positions, + const layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, - aclrtEvent* event, - std::atomic* event_flag) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag); + const ModelInputParams& input_params) { + return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); } void load_state_dict(const StateDict& state_dict) { decoder_layer_->load_state_dict(state_dict); } - void verify_loaded_weights(const std::string& prefix) const { - decoder_layer_->verify_loaded_weights(prefix); - } - - void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); } - - void prepare_expert_weight(const std::vector& expert_list) { - decoder_layer_->prepare_expert_weight(expert_list); + virtual void 
prepare_expert_weight(int32_t layer_id,
+                                      const std::vector& expert_ids) {
+    return;
   }
-
-  void update_expert_weight() { decoder_layer_->update_expert_weight(); }
+  virtual void update_expert_weight(int32_t layer_id) { return; }
 
  private:
   layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
@@ -95,114 +64,71 @@ TORCH_MODULE(DeepseekV2DecoderLayer);
 
 class DeepseekV2ModelImpl : public torch::nn::Module {
  public:
-  DeepseekV2ModelImpl(const ModelContext& context)
-      : device_(context.get_tensor_options().device()) {
+  DeepseekV2ModelImpl(const ModelContext& context) {
     auto options = context.get_tensor_options();
     auto model_args = context.get_model_args();
     auto parallel_args = context.get_parallel_args();
 
     blocks_ = register_module("layers", torch::nn::ModuleList());
     layers_.reserve(model_args.n_layers());
+
     // register submodules
-    device_ = options.device();
-    dtype_ = options.dtype().toScalarType();
     num_speculative_tokens_ = model_args.num_speculative_tokens();
 
-    // rotary positional embedding
-    auto inv_freq = rotary::apply_deepseek_yarn_rope_scaling(
-        model_args.rope_scaling_factor(),
-        model_args.rope_extrapolation_factor(),
-        model_args.rope_scaling_beta_fast(),
-        model_args.rope_scaling_beta_slow(),
-        model_args.rotary_dim(),
-        model_args.rope_theta(),
-        model_args.rope_scaling_original_max_position_embeddings());
-    embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
-    float sm_scale = 1.0f;
-    pos_emb_ = create_rotary_embedding(model_args,
-                                       model_args.rotary_dim(),
-                                       inv_freq,
-                                       /*interleaved=*/false,
-                                       sm_scale,
-                                       options);
-    atb_pos_emb_ = layer::PosEmbedding(context);
-
-    max_seq_len_ = model_args.max_position_embeddings();
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
-    attn_mask_ = layer::AttentionMask(options.device(),
-                                      options.dtype().toScalarType(),
-                                      /*mask_value=*/mask_value);
+    // MTP is not supported yet
+    if (num_speculative_tokens_ > 0) {
+      LOG(FATAL) << "DeepSeek MTP on MLU is not supported yet";
+    }
+    embed_tokens_ =
+        register_module("embed_tokens",
+                        layer::WordEmbedding(model_args.vocab_size(),
+                                             model_args.hidden_size(),
+                                             context.get_parallel_args(),
+                                             options));
+    norm_ = register_module(
+        "norm",
+        layer::RmsNorm(
+            model_args.hidden_size(), model_args.rms_norm_eps(), options));
+
+    // create decoder layers
     for (int32_t i = 0; i < model_args.n_layers(); ++i) {
-      auto block = DeepseekV2DecoderLayer(context, i, sm_scale);
+      auto block = DeepseekV2DecoderLayer(context, i);
       layers_.push_back(block);
       blocks_->push_back(block);
     }
 
-    norm_ = register_module("norm", layer::RmsNorm(context));
-    // dp_size_=4;
     dp_size_ = parallel_args.dp_size();
     std::vector indices;
     dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
     dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
     rank_ = parallel_args.rank();
-    mapping_data_ = parallel_args.mapping_data();
-    num_experts_per_tok_ = model_args.num_experts_per_tok();
     for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
       indices.push_back(i);
     }
   }
 
-  torch::Tensor forward(torch::Tensor tokens,
-                        torch::Tensor positions,
-                        std::vector& kv_caches,
-                        const ModelInputParams& input_params) {
-    if (dp_size_ > 1) {
-      if (tokens.sizes() == 0) {
-        tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
-        positions = torch::tensor({0}).to(torch::kInt32).to(device_);
-      }
-    }
-
-    auto h = embed_tokens_(tokens, 0);
-    auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
-    auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
-    auto
cos_pos = cos_sin_chunks[0].contiguous(); - auto sin_pos = cos_sin_chunks[1].contiguous(); - - torch::Tensor attn_mask; - if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { - attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); - } else { - attn_mask = attn_mask_.gen_free_mask( - num_speculative_tokens_ + 1, dtype_, device_); - } - + torch::Tensor forward_native(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + bool is_prefill = input_params.q_max_seq_len > 1; + auto attn_metadata = + layer::AttentionMetadata::build(input_params, is_prefill); + torch::Tensor h = embed_tokens_(tokens); for (size_t i = 0; i < layers_.size(); i++) { - aclrtEvent* event = nullptr; - std::atomic* event_flag = nullptr; - if (input_params.layer_synchronizer != nullptr) { - event = input_params.layer_synchronizer->get_event(i); - event_flag = input_params.layer_synchronizer->get_event_flag(i); - } - if (input_params.layer_wise_load_synchronizer != nullptr) { - if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { - return torch::Tensor(); - } - } - auto& layer = layers_[i]; - layer(h, - cos_pos, - sin_pos, - attn_mask, - kv_caches[i], - input_params, - event, - event_flag); + h = layer(h, positions, attn_metadata, kv_caches[i], input_params); } - return norm_(h, 0); + return norm_(h); + } + + // Provide batched signature to satisfy callers that pass vectors + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return forward_native(tokens, positions, kv_caches, input_params); } // load the weight from the checkpoint @@ -217,32 +143,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module { norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - - void merge_loaded_weights() { - embed_tokens_->merge_loaded_weights(); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->merge_loaded_weights(); - } - norm_->merge_loaded_weights(); - } - - void prepare_expert_weight(int32_t layer_id, - const std::vector& expert_ids) { - layers_[layer_id]->prepare_expert_weight(expert_ids); - } - - void update_expert_weight(int32_t layer_id) { - layers_[layer_id]->update_expert_weight(); - } - layer::WordEmbedding get_word_embedding() { return embed_tokens_; } void set_word_embedding(layer::WordEmbedding& word_embedding) { @@ -252,90 +152,21 @@ class DeepseekV2ModelImpl : public torch::nn::Module { private: torch::nn::ModuleList blocks_{nullptr}; std::vector layers_; - int32_t max_seq_len_ = 0; int32_t dp_rank_; int32_t rank_; int32_t dp_size_; int32_t dp_local_tp_size_; - nlohmann::json mapping_data_; - int32_t num_experts_per_tok_; int32_t num_speculative_tokens_ = 0; - at::Device device_; - torch::Dtype dtype_; layer::WordEmbedding embed_tokens_{nullptr}; - std::shared_ptr pos_emb_{nullptr}; - layer::PosEmbedding atb_pos_emb_{nullptr}; - layer::AttentionMask attn_mask_; layer::RmsNorm norm_{nullptr}; }; TORCH_MODULE(DeepseekV2Model); -class DeepseekV2ForCausalLMImpl : public torch::nn::Module { +class DeepseekV2ForCausalLMImpl + : public LlmForCausalLMImplBase { public: - DeepseekV2ForCausalLMImpl(const ModelContext& context) { - model_ = register_module("model", DeepseekV2Model(context)); - lm_head_ = register_module("lm_head", layer::LmHead(context)); - first_k_dense_replace_ = context.get_model_args().first_k_dense_replace(); - } - - // tokens: [num_tokens] - // positions: [num_tokens] token pos in the sequence - // returns: [num_tokens, hidden_size] - torch::Tensor forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - return model_(tokens, positions, kv_caches, input_params); - } - - // hidden_states: [num_tokens, hidden_size] - // seleted_idxes: [num_tokens] - // returns: [num_tokens, vocab_size] - torch::Tensor logits(const torch::Tensor& hidden_states, - const torch::Tensor& seleted_idxes) { - return lm_head_(hidden_states, seleted_idxes, 0); - } - - void load_model(std::unique_ptr loader) { - for (const auto& state_dict : loader->get_state_dicts()) { - model_->load_state_dict(state_dict->get_dict_with_prefix("model.")); - lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); - } - - // verify - model_->verify_loaded_weights("model."); - lm_head_->verify_loaded_weights("lm_head."); - - model_->merge_loaded_weights(); - lm_head_->merge_loaded_weights(); - } - - void prepare_expert_weight(int32_t layer_id, - const std::vector& expert_ids) { - model_->prepare_expert_weight(layer_id + first_k_dense_replace_, - expert_ids); - } - - void update_expert_weight(int32_t layer_id) { - model_->update_expert_weight(layer_id + first_k_dense_replace_); - } - - layer::LmHead get_lm_head() { return lm_head_; } - - void set_lm_head(layer::LmHead& head) { lm_head_ = head; } - - layer::WordEmbedding get_word_embedding() { - return model_->get_word_embedding(); - } - - void set_word_embedding(layer::WordEmbedding& word_embedding) { - model_->set_word_embedding(word_embedding); - } - - private: - DeepseekV2Model model_{nullptr}; - layer::LmHead lm_head_{nullptr}; - int32_t first_k_dense_replace_; + DeepseekV2ForCausalLMImpl(const ModelContext& context) + : 
LlmForCausalLMImplBase(context) {} }; TORCH_MODULE(DeepseekV2ForCausalLM); diff --git a/xllm/models/llm/deepseek_v3.h b/xllm/models/llm/deepseek_v3.h index 922bf1dfd..c96830f4c 100644 --- a/xllm/models/llm/deepseek_v3.h +++ b/xllm/models/llm/deepseek_v3.h @@ -42,6 +42,7 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] { LOAD_ARG_OR(max_window_layers, "max_window_layers", 61); LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 3); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256); @@ -52,6 +53,7 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] { LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); LOAD_ARG_OR(n_group, "n_group", 8); LOAD_ARG_OR(topk_group, "topk_group", 4); + LOAD_ARG_OR(scoring_func, "scoring_func", "sigmoid"); LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); diff --git a/xllm/models/llm/mlu/deepseek_v32.h b/xllm/models/llm/deepseek_v32.h similarity index 100% rename from xllm/models/llm/mlu/deepseek_v32.h rename to xllm/models/llm/deepseek_v32.h diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 7b1c1c36c..885942267 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -15,10 +15,6 @@ limitations under the License. #pragma once -#if defined(USE_NPU) -#include -#endif -#include #include #include @@ -31,16 +27,10 @@ limitations under the License. #include "core/framework/model/model_input_params.h" #include "core/framework/model_context.h" #include "core/layers/attention_mask.h" -#include "core/layers/block_copy.h" +#include "core/layers/common/layer_utils.h" #include "core/layers/lm_head.h" -#include "core/layers/pos_embedding.h" #include "core/layers/rms_norm.h" #include "models/model_registry.h" -#if defined(USE_NPU) -#include "xllm_kernels/core/include/atb_speed/log.h" -#else -#include "core/layers/common/layer_utils.h" -#endif #if defined(USE_CUDA) #include "core/layers/cuda/attention.h" #endif @@ -50,86 +40,14 @@ limitations under the License. 
namespace xllm { -torch::Tensor get_concat_rotary_embedding(int64_t dim, - int64_t seq_len, - double rope_theta, - const torch::TensorOptions& options) { - auto options_new = - torch::device(options.device()).dtype(at::ScalarType::Double); - auto inv_freq = - 1.0 / torch::pow(rope_theta, torch::arange(0, dim, 2, options_new) / dim) - .to(at::ScalarType::Float); - auto seq_idx = torch::arange(seq_len, options_new); - - auto freqs = torch::ger(seq_idx, inv_freq).to(torch::kFloat32); - auto emb = torch::cat({freqs, freqs}, -1); - auto rope_cos = torch::cos(emb); - auto rope_sin = torch::sin(emb); - - auto dtype = options.dtype(); - if (dtype == torch::kFloat16 || dtype == torch::kBFloat16 || - dtype == torch::kInt8) { - if (dtype == torch::kBFloat16) { - rope_cos = rope_cos.to(torch::kBFloat16); - rope_sin = rope_sin.to(torch::kBFloat16); - } else { - rope_cos = rope_cos.to(torch::kFloat16); - rope_sin = rope_sin.to(torch::kFloat16); - } - } - std::vector cos_sin{rope_cos, rope_sin}; - return torch::cat(cos_sin, -1); -} - template class LlmDecoderLayerImplBase : public torch::nn::Module { public: LlmDecoderLayerImplBase(const ModelContext& context) { // register submodules decoder_layer_ = register_module("decoder_layer", DecoderType(context)); -#if defined(USE_NPU) - block_copy_ = register_module("block_copy", layer::BlockCopy(context)); -#endif } -#if defined(USE_NPU) - virtual torch::Tensor forward(torch::Tensor& x, - torch::Tensor& cos_pos, - torch::Tensor& sin_pos, - torch::Tensor& attn_mask, - KVCache& kv_cache, - ModelInputParams& input_params, - int node_id, - aclrtEvent* event, - std::atomic* event_flag) { - if (input_params.src_block_indices.numel() > 0) { - block_copy_(kv_cache.get_k_cache(), - kv_cache.get_v_cache(), - input_params.src_block_indices, - input_params.dst_block_indices, - input_params.cum_sum, - 0); - } - - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - event, - event_flag, - node_id); - } - - virtual void verify_loaded_weights(const std::string& prefix) const { - decoder_layer_->verify_loaded_weights(); - } - virtual void merge_loaded_weights() { - decoder_layer_->merge_loaded_weights(); - block_copy_->merge_loaded_weights(); - } -#else virtual torch::Tensor forward(torch::Tensor& x, torch::Tensor& positions, const layer::AttentionMetadata& attn_metadata, @@ -137,7 +55,6 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { const ModelInputParams& input_params) { return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); } -#endif // load the weight from the checkpoint virtual void load_state_dict(const StateDict& state_dict) { @@ -147,9 +64,6 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { private: DecoderType decoder_layer_{nullptr}; -#if defined(USE_NPU) - layer::BlockCopy block_copy_{nullptr}; -#endif }; template @@ -165,11 +79,7 @@ class LlmModelImplBase : public torch::nn::Module { } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { -#if defined(USE_NPU) - return embed_tokens_(input_ids, 0); -#else return embed_tokens_(input_ids); -#endif } // tokens: [num_tokens] @@ -188,127 +98,23 @@ class LlmModelImplBase : public torch::nn::Module { if (inputs_embeds.defined()) { h = inputs_embeds; } else { -#if defined(USE_NPU) - h = embed_tokens_(tokens, 0); -#else h = embed_tokens_(tokens); -#endif - } - -#if defined(USE_NPU) - auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); -#else - auto target_cos_sin = cos_sin_.index({positions}); -#endif - auto 
target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = target_cos_sin_chunks[0].contiguous(); - auto sin_pos = target_cos_sin_chunks[1].contiguous(); - - if (positions.dim() == 2) { // mrope - auto apply = [this](torch::Tensor x) { - auto sections = mrope_section_; - sections.insert(sections.end(), sections.begin(), sections.end()); - - auto vec = x.split(sections, -1); - std::vector selects; - selects.reserve(vec.size()); - - for (int64_t i = 0; i < vec.size(); ++i) { - auto m = vec[i]; - selects.push_back(m[i % mrope_section_.size()]); - } - return torch::cat(selects, -1); - }; - cos_pos = apply(cos_pos.reshape( - {positions.sizes().front(), -1, cos_pos.sizes().back()})); - sin_pos = apply(sin_pos.reshape( - {positions.sizes().front(), -1, sin_pos.sizes().back()})); - } - - ModelInputParams& input_params_new = - const_cast(input_params); - torch::Tensor attn_mask; - if (model_type_ == "qwen2") { - max_seq_len_ = FLAGS_enable_chunked_prefill - ? std::max(input_params.kv_max_seq_len, max_seq_len_) - : 128; - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - } else { - max_seq_len_ = FLAGS_enable_chunked_prefill - ? std::max(input_params.kv_max_seq_len, max_seq_len_) - : 128; - if (FLAGS_enable_chunked_prefill) { - int num_sequences = input_params.num_sequences; - if (num_sequences > 0) { - std::vector req_mask_vec; - req_mask_vec.reserve(num_sequences); - - for (int j = 0; j < num_sequences; j++) { - auto mask = - attn_mask_.gen_append_mask(input_params.q_seq_lens_vec[j], - input_params.kv_seq_lens_vec[j], - max_seq_len_, - cos_pos.dtype().toScalarType(), - cos_pos.device()); - req_mask_vec.emplace_back(mask); - } - attn_mask = torch::cat(req_mask_vec, 0); - } - } else { - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - } } -#if defined(USE_NPU) - for (size_t i = 0; i < layers_.size(); i++) { - aclrtEvent* event = nullptr; - std::atomic* event_flag = nullptr; - if (input_params.layer_synchronizer != nullptr) { - event = input_params.layer_synchronizer->get_event(i); - event_flag = input_params.layer_synchronizer->get_event_flag(i); - } - if (input_params.layer_wise_load_synchronizer != nullptr) { - if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { - return torch::Tensor(); - } - } - - auto& layer = layers_[i]; - - if (layer_forward_interrupted_) { - VLOG(1) << "Forward interrupted at layer: " << i; - return torch::Tensor(); - } - - layer(h, - cos_pos, - sin_pos, - attn_mask, - kv_caches[i], - input_params_new, - i, - event, - event_flag); - } - return norm_(h, 0); -#else - layer::update_dummy_run_input(dp_rank_, positions, input_params_new); - bool is_prefill = input_params_new.q_max_seq_len > 1; + auto modified_input_params = input_params; + auto position = positions; + layer::update_dummy_run_input(dp_rank_, position, modified_input_params); + bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = - layer::AttentionMetadata::build(input_params_new, is_prefill); - if (positions.dim() == 2) { - attn_metadata.mrope_cos = std::move(cos_pos); - attn_metadata.mrope_sin = std::move(sin_pos); - } + layer::AttentionMetadata::build(modified_input_params, is_prefill); + torch::Tensor h_ret; for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; - h = layer(h, positions, attn_metadata, kv_caches[i], input_params_new); + h_ret = layer( + h, position, attn_metadata, kv_caches[i], 
modified_input_params); } - return norm_(h); -#endif + return norm_(h_ret); } // load the weight from the checkpoint @@ -324,27 +130,6 @@ class LlmModelImplBase : public torch::nn::Module { norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } -#if defined(USE_NPU) - virtual void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - - virtual void merge_loaded_weights() { - embed_tokens_->merge_loaded_weights(); - - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->merge_loaded_weights(); - } - norm_->merge_loaded_weights(); - } -#endif - virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; } virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { @@ -352,16 +137,12 @@ class LlmModelImplBase : public torch::nn::Module { } protected: - torch::Tensor cos_sin_; int max_seq_len_ = 0; torch::Tensor cos_pos_; torch::Tensor sin_pos_; int device_id = 0; layer::AttentionMask attn_mask_; int dp_rank_ = 0; -#if defined(USE_NPU) - layer::PosEmbedding atb_pos_emb_{nullptr}; -#endif std::vector mrope_section_; // test @@ -411,15 +192,10 @@ class LlmForCausalLMImplBase : public torch::nn::Module { const torch::Tensor& seleted_idxes) { // select tokens if provided auto h = hidden_states; - // test -#if defined(USE_NPU) - return lm_head_(hidden_states, seleted_idxes, 0); -#else if (seleted_idxes.defined()) { h = h.index_select(/*dim=*/0, seleted_idxes); } return lm_head_(h); -#endif } void load_model(std::unique_ptr loader, @@ -433,15 +209,6 @@ class LlmForCausalLMImplBase : public torch::nn::Module { lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } } -#if defined(USE_NPU) - // verify - model_->verify_loaded_weights(prefix); - lm_head_->verify_loaded_weights("lm_head."); - - model_->merge_loaded_weights(); - // test - lm_head_->merge_loaded_weights(); -#endif } virtual void prepare_expert_weight(int32_t layer_id, @@ -467,7 +234,6 @@ class LlmForCausalLMImplBase : public torch::nn::Module { LlmModelType model_{nullptr}; int device_id = 0; bool tie_word_embeddings{false}; - // test layer::LmHead lm_head_{nullptr}; }; diff --git a/xllm/models/llm/mlu/deepseek_v2.h b/xllm/models/llm/mlu/deepseek_v2.h deleted file mode 100644 index 733d3e312..000000000 --- a/xllm/models/llm/mlu/deepseek_v2.h +++ /dev/null @@ -1,234 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#pragma once - -#include - -#include -#include - -#include "core/layers/deepseek_v2_decoder_layer.h" -#include "models/llm/llm_model_base.h" - -// DeepSeek v2 compatible with huggingface weights -// ref to: -// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py - -namespace xllm { - -using torch::indexing::None; -using ISlice = torch::indexing::Slice; - -class DeepseekV2DecoderLayerImpl : public torch::nn::Module { - public: - DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) { - // register submodules - decoder_layer_ = register_module("decoder_layer", - layer::DeepseekV2DecoderLayer(context, i)); - } - - torch::Tensor forward(torch::Tensor& x, - torch::Tensor& positions, - const layer::AttentionMetadata& attn_metadata, - KVCache& kv_cache, - const ModelInputParams& input_params) { - return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); - } - - void load_state_dict(const StateDict& state_dict) { - decoder_layer_->load_state_dict(state_dict); - } - - virtual void prepare_expert_weight(int32_t layer_id, - const std::vector& expert_ids) { - return; - } - virtual void update_expert_weight(int32_t layer_id) { return; } - - private: - layer::DeepseekV2DecoderLayer decoder_layer_{nullptr}; -}; -TORCH_MODULE(DeepseekV2DecoderLayer); - -class DeepseekV2ModelImpl : public torch::nn::Module { - public: - DeepseekV2ModelImpl(const ModelContext& context) { - auto options = context.get_tensor_options(); - auto model_args = context.get_model_args(); - auto parallel_args = context.get_parallel_args(); - - blocks_ = register_module("layers", torch::nn::ModuleList()); - layers_.reserve(model_args.n_layers()); - - // register submodules - num_speculative_tokens_ = model_args.num_speculative_tokens(); - - // MTP is not support for now - if (num_speculative_tokens_ > 0) { - LOG(FATAL) << "DeepSeek MTP on MLU is not support for now"; - } - - embed_tokens_ = - register_module("embed_tokens", - layer::WordEmbedding(model_args.vocab_size(), - model_args.hidden_size(), - context.get_parallel_args(), - options)); - norm_ = register_module( - "norm", - layer::RmsNorm( - model_args.hidden_size(), model_args.rms_norm_eps(), options)); - - // create decoder layers - for (int32_t i = 0; i < model_args.n_layers(); ++i) { - auto block = DeepseekV2DecoderLayer(context, i); - layers_.push_back(block); - blocks_->push_back(block); - } - - dp_size_ = parallel_args.dp_size(); - std::vector indices; - dp_local_tp_size_ = parallel_args.world_size() / dp_size_; - dp_rank_ = parallel_args.rank() / dp_local_tp_size_; - rank_ = parallel_args.rank(); - for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) { - indices.push_back(i); - } - } - - torch::Tensor forward_native(torch::Tensor tokens, - torch::Tensor positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { - bool is_prefill = input_params.q_max_seq_len > 1; - auto attn_metadata = - layer::AttentionMetadata::build(input_params, is_prefill); - torch::Tensor h = embed_tokens_(tokens); - for (size_t i = 0; i < layers_.size(); i++) { - auto& layer = layers_[i]; - h = layer(h, positions, attn_metadata, kv_caches[i], input_params); - } - return norm_(h); - } - - // Provide batched signature to satisfy callers that pass vectors - torch::Tensor forward(const torch::Tensor& tokens, - const torch::Tensor& positions, - std::vector& kv_caches, - const ModelInputParams& input_params) { 
- return forward_native(tokens, positions, kv_caches, input_params); - } - - // load the weight from the checkpoint - void load_state_dict(const StateDict& state_dict) { - embed_tokens_->load_state_dict( - state_dict.get_dict_with_prefix("embed_tokens.")); - // call each layer's load_state_dict function - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->load_state_dict( - state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); - } - norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); - } - - layer::WordEmbedding get_word_embedding() { return embed_tokens_; } - - void set_word_embedding(layer::WordEmbedding& word_embedding) { - embed_tokens_ = word_embedding; - } - - private: - torch::nn::ModuleList blocks_{nullptr}; - std::vector layers_; - int32_t dp_rank_; - int32_t rank_; - int32_t dp_size_; - int32_t dp_local_tp_size_; - int32_t num_speculative_tokens_ = 0; - layer::WordEmbedding embed_tokens_{nullptr}; - layer::RmsNorm norm_{nullptr}; -}; -TORCH_MODULE(DeepseekV2Model); - -class DeepseekV2ForCausalLMImpl - : public LlmForCausalLMImplBase { - public: - DeepseekV2ForCausalLMImpl(const ModelContext& context) - : LlmForCausalLMImplBase(context) {} -}; -TORCH_MODULE(DeepseekV2ForCausalLM); - -// register the causal model -REGISTER_CAUSAL_MODEL(deepseek_v2, DeepseekV2ForCausalLM); - -// register the model args -// example config: -// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json -REGISTER_MODEL_ARGS(deepseek_v2, [&] { - LOAD_ARG_OR(model_type, "model_type", "deepseek_v2"); - LOAD_ARG_OR(dtype, "torch_dtype", ""); - LOAD_ARG_OR(vocab_size, "vocab_size", 102400); - LOAD_ARG_OR(hidden_size, "hidden_size", 2048); - LOAD_ARG_OR(n_layers, "num_hidden_layers", 27); - LOAD_ARG_OR(n_heads, "num_attention_heads", 16); - LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 16); - LOAD_ARG_OR(intermediate_size, "intermediate_size", 10944); - LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 163840); - LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); - LOAD_ARG_OR(eos_token_id, "eos_token_id", 100001); - LOAD_ARG_OR(bos_token_id, "bos_token_id", 100000); - LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f); - LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); - LOAD_ARG_OR(sliding_window, "sliding_window", 4096); - LOAD_ARG_OR(max_window_layers, "max_window_layers", 27); - - LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 1); - LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); - LOAD_ARG_OR(topk_method, "topk_method", "greedy"); - LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 64); - LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 2); - LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 6); - LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 1408); - LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 1.0f); - LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", false); - LOAD_ARG_OR(n_group, "n_group", 1); - LOAD_ARG_OR(topk_group, "topk_group", 1); - LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); - LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); - LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); - LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 0); - LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); - - LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { - return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); - }); - LOAD_ARG_OR_FUNC( - rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); }); - - SET_ARG(rope_scaling_rope_type, "deepseek_yarn"); - 
LOAD_ARG(rope_scaling_beta_fast, "rope_scaling.beta_fast"); - LOAD_ARG(rope_scaling_beta_slow, "rope_scaling.beta_slow"); - LOAD_ARG(rope_scaling_factor, "rope_scaling.factor"); - LOAD_ARG_OR( - rope_extrapolation_factor, "rope_scaling.extrapolation_factor", 1.0f); - LOAD_ARG(rope_scaling_mscale, "rope_scaling.mscale"); - LOAD_ARG(rope_scaling_mscale_all_dim, "rope_scaling.mscale_all_dim"); - LOAD_ARG(rope_scaling_original_max_position_embeddings, - "rope_scaling.original_max_position_embeddings"); - LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling.attn_factor", 1.0f); - - SET_ARG(stop_token_ids, std::unordered_set({100001})); -}); -} // namespace xllm diff --git a/xllm/models/llm/npu/deepseek_v2.h b/xllm/models/llm/npu/deepseek_v2.h new file mode 100644 index 000000000..5d85f8af8 --- /dev/null +++ b/xllm/models/llm/npu/deepseek_v2.h @@ -0,0 +1,403 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include +#include +#include + +#include "core/common/global_flags.h" +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model/npu_dp_ep_padding.h" +#include "core/framework/model_context.h" +#include "core/layers/attention_mask.h" +#include "core/layers/deepseek_v2_decoder_layer.h" +#include "core/layers/lm_head.h" +#include "core/layers/pos_embedding.h" +#include "core/layers/rms_norm.h" +#include "core/layers/rotary_embedding.h" +#include "core/layers/word_embedding.h" +#include "models/model_registry.h" +// DeepSeek v2 compatible with huggingface weights +// ref to: +// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py + +namespace xllm { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class DeepseekV2DecoderLayerImpl : public torch::nn::Module { + public: + DeepseekV2DecoderLayerImpl(const ModelContext& context, + const int32_t i, + const float sm_scale) { + // register submodules + decoder_layer_ = register_module( + "decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + const ModelInputParams& input_params, + aclrtEvent* event, + std::atomic* event_flag) { + return decoder_layer_(x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + event, + event_flag); + } + + void load_state_dict(const StateDict& state_dict) { + decoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + decoder_layer_->verify_loaded_weights(prefix); + } + + void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); } + + void prepare_expert_weight(const std::vector& expert_list) { + decoder_layer_->prepare_expert_weight(expert_list); + } + + void 
update_expert_weight() { decoder_layer_->update_expert_weight(); } + + private: + layer::DeepseekV2DecoderLayer decoder_layer_{nullptr}; +}; +TORCH_MODULE(DeepseekV2DecoderLayer); + +class DeepseekV2ModelImpl : public torch::nn::Module { + public: + DeepseekV2ModelImpl(const ModelContext& context) + : device_(context.get_tensor_options().device()) { + auto options = context.get_tensor_options(); + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + // register submodules + device_ = options.device(); + dtype_ = options.dtype().toScalarType(); + num_speculative_tokens_ = model_args.num_speculative_tokens(); + + // rotary positional embedding + auto inv_freq = rotary::apply_deepseek_yarn_rope_scaling( + model_args.rope_scaling_factor(), + model_args.rope_extrapolation_factor(), + model_args.rope_scaling_beta_fast(), + model_args.rope_scaling_beta_slow(), + model_args.rotary_dim(), + model_args.rope_theta(), + model_args.rope_scaling_original_max_position_embeddings()); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); + float sm_scale = 1.0f; + pos_emb_ = create_rotary_embedding(model_args, + model_args.rotary_dim(), + inv_freq, + /*interleaved=*/false, + sm_scale, + options); + atb_pos_emb_ = layer::PosEmbedding(context); + + max_seq_len_ = model_args.max_position_embeddings(); + int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + + for (int32_t i = 0; i < model_args.n_layers(); ++i) { + auto block = DeepseekV2DecoderLayer(context, i, sm_scale); + layers_.push_back(block); + blocks_->push_back(block); + } + + norm_ = register_module("norm", layer::RmsNorm(context)); + // dp_size_=4; + dp_size_ = parallel_args.dp_size(); + std::vector indices; + dp_local_tp_size_ = parallel_args.world_size() / dp_size_; + dp_rank_ = parallel_args.rank() / dp_local_tp_size_; + rank_ = parallel_args.rank(); + mapping_data_ = parallel_args.mapping_data(); + num_experts_per_tok_ = model_args.num_experts_per_tok(); + for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) { + indices.push_back(i); + } + } + + torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); + } + } + + auto h = embed_tokens_(tokens, 0); + auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0); + auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = cos_sin_chunks[0].contiguous(); + auto sin_pos = cos_sin_chunks[1].contiguous(); + + torch::Tensor attn_mask; + if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { + attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); + } else { + attn_mask = attn_mask_.gen_free_mask( + num_speculative_tokens_ + 1, dtype_, device_); + } + + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if 
(input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); + } + } + + auto& layer = layers_[i]; + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params, + event, + event_flag); + } + return norm_(h, 0); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // call each layer's load_state_dict function + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "norm."); + } + + void merge_loaded_weights() { + embed_tokens_->merge_loaded_weights(); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->merge_loaded_weights(); + } + norm_->merge_loaded_weights(); + } + + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + layers_[layer_id]->prepare_expert_weight(expert_ids); + } + + void update_expert_weight(int32_t layer_id) { + layers_[layer_id]->update_expert_weight(); + } + + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } + + private: + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + int32_t max_seq_len_ = 0; + int32_t dp_rank_; + int32_t rank_; + int32_t dp_size_; + int32_t dp_local_tp_size_; + nlohmann::json mapping_data_; + int32_t num_experts_per_tok_; + int32_t num_speculative_tokens_ = 0; + at::Device device_; + torch::Dtype dtype_; + layer::WordEmbedding embed_tokens_{nullptr}; + std::shared_ptr pos_emb_{nullptr}; + layer::PosEmbedding atb_pos_emb_{nullptr}; + layer::AttentionMask attn_mask_; + layer::RmsNorm norm_{nullptr}; +}; +TORCH_MODULE(DeepseekV2Model); + +class DeepseekV2ForCausalLMImpl : public torch::nn::Module { + public: + DeepseekV2ForCausalLMImpl(const ModelContext& context) { + model_ = register_module("model", DeepseekV2Model(context)); + lm_head_ = register_module("lm_head", layer::LmHead(context)); + first_k_dense_replace_ = context.get_model_args().first_k_dense_replace(); + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + // returns: [num_tokens, hidden_size] + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return lm_head_(hidden_states, seleted_idxes, 0); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict(state_dict->get_dict_with_prefix("model.")); + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); + } + + // verify + 
model_->verify_loaded_weights("model."); + lm_head_->verify_loaded_weights("lm_head."); + + model_->merge_loaded_weights(); + lm_head_->merge_loaded_weights(); + } + + void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + model_->prepare_expert_weight(layer_id + first_k_dense_replace_, + expert_ids); + } + + void update_expert_weight(int32_t layer_id) { + model_->update_expert_weight(layer_id + first_k_dense_replace_); + } + + layer::LmHead get_lm_head() { return lm_head_; } + + void set_lm_head(layer::LmHead& head) { lm_head_ = head; } + + layer::WordEmbedding get_word_embedding() { + return model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + model_->set_word_embedding(word_embedding); + } + + private: + DeepseekV2Model model_{nullptr}; + layer::LmHead lm_head_{nullptr}; + int32_t first_k_dense_replace_; +}; +TORCH_MODULE(DeepseekV2ForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(deepseek_v2, DeepseekV2ForCausalLM); + +// register the model args +// example config: +// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json +REGISTER_MODEL_ARGS(deepseek_v2, [&] { + LOAD_ARG_OR(model_type, "model_type", "deepseek_v2"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 102400); + LOAD_ARG_OR(hidden_size, "hidden_size", 2048); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 27); + LOAD_ARG_OR(n_heads, "num_attention_heads", 16); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 16); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 10944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 163840); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 100001); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 100000); + LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f); + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(sliding_window, "sliding_window", 4096); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 27); + + LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 1); + LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); + LOAD_ARG_OR(topk_method, "topk_method", "greedy"); + LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 64); + LOAD_ARG_OR(n_shared_experts, "n_shared_experts", 2); + LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 6); + LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 1408); + LOAD_ARG_OR(routed_scaling_factor, "routed_scaling_factor", 1.0f); + LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", false); + LOAD_ARG_OR(n_group, "n_group", 1); + LOAD_ARG_OR(topk_group, "topk_group", 1); + LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); + LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); + LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); + LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 0); + LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return 256; // args->qk_nope_head_dim() + args->qk_rope_head_dim(); + }); + LOAD_ARG_OR_FUNC( + rotary_dim, "rotary_dim", [&] { return args->qk_rope_head_dim(); }); + + SET_ARG(rope_scaling_rope_type, "deepseek_yarn"); + LOAD_ARG(rope_scaling_beta_fast, "rope_scaling.beta_fast"); + LOAD_ARG(rope_scaling_beta_slow, "rope_scaling.beta_slow"); + LOAD_ARG(rope_scaling_factor, "rope_scaling.factor"); + LOAD_ARG_OR( + rope_extrapolation_factor, "rope_scaling.extrapolation_factor", 1.0f); + LOAD_ARG(rope_scaling_mscale, "rope_scaling.mscale"); + 
LOAD_ARG(rope_scaling_mscale_all_dim, "rope_scaling.mscale_all_dim"); + LOAD_ARG(rope_scaling_original_max_position_embeddings, + "rope_scaling.original_max_position_embeddings"); + LOAD_ARG_OR(rope_scaling_attn_factor, "rope_scaling.attn_factor", 1.0f); + + SET_ARG(stop_token_ids, std::unordered_set({100001})); +}); +} // namespace xllm diff --git a/xllm/models/llm/deepseek_v2_mtp.h b/xllm/models/llm/npu/deepseek_v2_mtp.h similarity index 100% rename from xllm/models/llm/deepseek_v2_mtp.h rename to xllm/models/llm/npu/deepseek_v2_mtp.h diff --git a/xllm/models/llm/mlu/deepseek_v3.h b/xllm/models/llm/npu/deepseek_v3.h similarity index 97% rename from xllm/models/llm/mlu/deepseek_v3.h rename to xllm/models/llm/npu/deepseek_v3.h index c96830f4c..922bf1dfd 100644 --- a/xllm/models/llm/mlu/deepseek_v3.h +++ b/xllm/models/llm/npu/deepseek_v3.h @@ -42,7 +42,6 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] { LOAD_ARG_OR(max_window_layers, "max_window_layers", 61); LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 3); - LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1); LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc"); LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256); @@ -53,7 +52,6 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] { LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); LOAD_ARG_OR(n_group, "n_group", 8); LOAD_ARG_OR(topk_group, "topk_group", 4); - LOAD_ARG_OR(scoring_func, "scoring_func", "sigmoid"); LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128); LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64); LOAD_ARG_OR(v_head_dim, "v_head_dim", 128); diff --git a/xllm/models/llm/embedding_model_base.h b/xllm/models/llm/npu/embedding_model_base.h similarity index 97% rename from xllm/models/llm/embedding_model_base.h rename to xllm/models/llm/npu/embedding_model_base.h index 058ab2545..68351a4b4 100644 --- a/xllm/models/llm/embedding_model_base.h +++ b/xllm/models/llm/npu/embedding_model_base.h @@ -15,7 +15,7 @@ limitations under the License. #pragma once -#include "models/llm/llm_model_base.h" +#include "llm_model_base.h" namespace xllm { @@ -60,11 +60,8 @@ class LlmForEmbeddingImplBase : public torch::nn::Module { } model_->load_state_dict(sub_dict); } -#if defined(USE_NPU) - // verify model_->verify_loaded_weights(prefix + "model."); model_->merge_loaded_weights(); -#endif } virtual void prepare_expert_weight(int32_t layer_id, diff --git a/xllm/models/llm/glm4_moe.h b/xllm/models/llm/npu/glm4_moe.h similarity index 100% rename from xllm/models/llm/glm4_moe.h rename to xllm/models/llm/npu/glm4_moe.h diff --git a/xllm/models/llm/glm4_moe_mtp.h b/xllm/models/llm/npu/glm4_moe_mtp.h similarity index 100% rename from xllm/models/llm/glm4_moe_mtp.h rename to xllm/models/llm/npu/glm4_moe_mtp.h diff --git a/xllm/models/llm/kimi_k2.h b/xllm/models/llm/npu/kimi_k2.h similarity index 100% rename from xllm/models/llm/kimi_k2.h rename to xllm/models/llm/npu/kimi_k2.h diff --git a/xllm/models/llm/llama.h b/xllm/models/llm/npu/llama.h similarity index 100% rename from xllm/models/llm/llama.h rename to xllm/models/llm/npu/llama.h diff --git a/xllm/models/llm/llama3.h b/xllm/models/llm/npu/llama3.h similarity index 100% rename from xllm/models/llm/llama3.h rename to xllm/models/llm/npu/llama3.h diff --git a/xllm/models/llm/npu/llm_model_base.h b/xllm/models/llm/npu/llm_model_base.h new file mode 100644 index 000000000..d79d9d133 --- /dev/null +++ b/xllm/models/llm/npu/llm_model_base.h @@ -0,0 +1,402 @@ +/* Copyright 2025 The xLLM Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "core/common/global_flags.h" +#include "core/common/interruption_bus.h" +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model_context.h" +#include "core/layers/attention_mask.h" +#include "core/layers/block_copy.h" +#include "core/layers/lm_head.h" +#include "core/layers/pos_embedding.h" +#include "core/layers/rms_norm.h" +#include "models/model_registry.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +torch::Tensor get_concat_rotary_embedding(int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options) { + auto options_new = + torch::device(options.device()).dtype(at::ScalarType::Double); + auto inv_freq = + 1.0 / torch::pow(rope_theta, torch::arange(0, dim, 2, options_new) / dim) + .to(at::ScalarType::Float); + auto seq_idx = torch::arange(seq_len, options_new); + + auto freqs = torch::ger(seq_idx, inv_freq).to(torch::kFloat32); + auto emb = torch::cat({freqs, freqs}, -1); + auto rope_cos = torch::cos(emb); + auto rope_sin = torch::sin(emb); + + auto dtype = options.dtype(); + if (dtype == torch::kFloat16 || dtype == torch::kBFloat16 || + dtype == torch::kInt8) { + if (dtype == torch::kBFloat16) { + rope_cos = rope_cos.to(torch::kBFloat16); + rope_sin = rope_sin.to(torch::kBFloat16); + } else { + rope_cos = rope_cos.to(torch::kFloat16); + rope_sin = rope_sin.to(torch::kFloat16); + } + } + std::vector cos_sin{rope_cos, rope_sin}; + return torch::cat(cos_sin, -1); +} + +template +class LlmDecoderLayerImplBase : public torch::nn::Module { + public: + LlmDecoderLayerImplBase(const ModelContext& context) { + // register submodules + decoder_layer_ = register_module("decoder_layer", DecoderType(context)); + block_copy_ = register_module("block_copy", layer::BlockCopy(context)); + } + + virtual torch::Tensor forward(torch::Tensor& x, + torch::Tensor& cos_pos, + torch::Tensor& sin_pos, + torch::Tensor& attn_mask, + KVCache& kv_cache, + ModelInputParams& input_params, + int node_id, + aclrtEvent* event, + std::atomic* event_flag) { + if (input_params.src_block_indices.numel() > 0) { + block_copy_(kv_cache.get_k_cache(), + kv_cache.get_v_cache(), + input_params.src_block_indices, + input_params.dst_block_indices, + input_params.cum_sum, + 0); + } + + return decoder_layer_(x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + event, + event_flag, + node_id); + } + + virtual void verify_loaded_weights(const std::string& prefix) const { + decoder_layer_->verify_loaded_weights(); + } + virtual void merge_loaded_weights() { + decoder_layer_->merge_loaded_weights(); + block_copy_->merge_loaded_weights(); + } + + // load the weight from the checkpoint + virtual void load_state_dict(const StateDict& state_dict) { + // call 
each submodule's load_state_dict function + decoder_layer_->load_state_dict(state_dict); + } + + private: + DecoderType decoder_layer_{nullptr}; + layer::BlockCopy block_copy_{nullptr}; +}; + +template +class LlmModelImplBase : public torch::nn::Module { + public: + // mode type: qwen2, qwen3 .etc + LlmModelImplBase(const std::string& model_type, const ModelArgs& args) + : model_type_(model_type) { + InterruptionBus::get_instance().subscribe([this](bool interrupted) { + this->layer_forward_interrupted_ = interrupted; + }); + mrope_section_ = args.rope_scaling_mrope_section(); + } + + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return embed_tokens_(input_ids, 0); + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + virtual torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + // test + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens, 0); + } + + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + auto sections = mrope_section_; + sections.insert(sections.end(), sections.begin(), sections.end()); + + auto vec = x.split(sections, -1); + std::vector selects; + selects.reserve(vec.size()); + + for (int64_t i = 0; i < vec.size(); ++i) { + auto m = vec[i]; + selects.push_back(m[i % mrope_section_.size()]); + } + return torch::cat(selects, -1); + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } + + ModelInputParams& input_params_new = + const_cast(input_params); + torch::Tensor attn_mask; + if (model_type_ == "qwen2") { + max_seq_len_ = FLAGS_enable_chunked_prefill + ? std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + } else { + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + if (FLAGS_enable_chunked_prefill) { + int num_sequences = input_params.num_sequences; + if (num_sequences > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(num_sequences); + + for (int j = 0; j < num_sequences; j++) { + auto mask = + attn_mask_.gen_append_mask(input_params.q_seq_lens_vec[j], + input_params.kv_seq_lens_vec[j], + max_seq_len_, + cos_pos.dtype().toScalarType(), + cos_pos.device()); + req_mask_vec.emplace_back(mask); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else { + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + } + } + + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); + } + } + + auto& layer = layers_[i]; + + if (layer_forward_interrupted_) { + VLOG(1) << "Forward interrupted at layer: " << i; + return torch::Tensor(); + } + + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params_new, + i, + event, + event_flag); + } + return norm_(h, 0); + } + + // load the weight from the checkpoint + virtual void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + + // call each layer's load_state_dict function + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + virtual void verify_loaded_weights(const std::string& prefix) const { + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "norm."); + } + + virtual void merge_loaded_weights() { + embed_tokens_->merge_loaded_weights(); + + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->merge_loaded_weights(); + } + norm_->merge_loaded_weights(); + } + + virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } + + protected: + torch::Tensor cos_sin_; + int max_seq_len_ = 0; + torch::Tensor cos_pos_; + torch::Tensor sin_pos_; + int device_id = 0; + layer::AttentionMask attn_mask_; + int dp_rank_ = 0; + layer::PosEmbedding atb_pos_emb_{nullptr}; + + std::vector mrope_section_; + layer::WordEmbedding embed_tokens_{nullptr}; + layer::RmsNorm norm_{nullptr}; + + torch::nn::ModuleList blocks_{nullptr}; + // holds the same layers as blocks_ but with their concrete type to avoid casts + std::vector layers_; + + bool layer_forward_interrupted_ = false; + + private: + std::string model_type_; +}; + +template +class LlmForCausalLMImplBase : public torch::nn::Module { + public: + LlmForCausalLMImplBase(const ModelContext& context) { + tie_word_embeddings = context.get_model_args().tie_word_embeddings(); + // register submodules + model_ = register_module("model", LlmModelType(context)); + + lm_head_ = register_module("lm_head", layer::LmHead(context)); + } + + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return model_->get_input_embeddings(input_ids); + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + // returns: [num_tokens, hidden_size] + virtual torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + // select tokens if provided + return lm_head_(hidden_states, seleted_idxes, 0); + } + + void load_model(std::unique_ptr loader, + std::string prefix = "model."
/*llm model weight prefix*/) { + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); + if (tie_word_embeddings) { + lm_head_->load_state_dict( + state_dict->get_dict_with_prefix(prefix + "embed_tokens.")); + } else { + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); + } + } + model_->verify_loaded_weights(prefix); + lm_head_->verify_loaded_weights("lm_head."); + + model_->merge_loaded_weights(); + lm_head_->merge_loaded_weights(); + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + + virtual layer::LmHead get_lm_head() { return lm_head_; } + + virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; } + + virtual layer::WordEmbedding get_word_embedding() { + return model_->get_word_embedding(); + } + + virtual void set_word_embedding(layer::WordEmbedding& word_embedding) { + model_->set_word_embedding(word_embedding); + } + + protected: + // parameter members, must be registered + LlmModelType model_{nullptr}; + int device_id = 0; + bool tie_word_embeddings{false}; + layer::LmHead lm_head_{nullptr}; +}; + +} // namespace xllm diff --git a/xllm/models/llm/npu/qwen2.h b/xllm/models/llm/npu/qwen2.h new file mode 100644 index 000000000..fe3cae507 --- /dev/null +++ b/xllm/models/llm/npu/qwen2.h @@ -0,0 +1,116 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. +Copyright 2024 The ScaleLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include "core/layers/qwen2_decoder_layer.h" +#include "llm_model_base.h" + +// QWen2 model compatible with huggingface weights +// ref to: +// https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py +namespace xllm { + +class QWen2DecoderLayerImpl + : public LlmDecoderLayerImplBase { + public: + QWen2DecoderLayerImpl(const ModelContext& context) + : LlmDecoderLayerImplBase(context) {} +}; +TORCH_MODULE(QWen2DecoderLayer); + +class QWen2ModelImpl : public LlmModelImplBase { + public: + QWen2ModelImpl(const ModelContext& context) + : LlmModelImplBase("qwen2", context.get_model_args()) { + // register submodules + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + auto dp_local_tp_size = + parallel_args.world_size() / parallel_args.dp_size(); + dp_rank_ = parallel_args.rank() / dp_local_tp_size; + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + norm_ = register_module("norm", layer::RmsNorm(context)); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); + atb_pos_emb_ = layer::PosEmbedding(context); + cos_sin_ = get_concat_rotary_embedding( + model_args.hidden_size() / model_args.n_heads(), + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); + int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + + for (int32_t i = 0; i < model_args.n_layers(); i++) { + auto block = QWen2DecoderLayer(context); + layers_.push_back(block); + blocks_->push_back(block); + } + } +}; +TORCH_MODULE(QWen2Model); + +class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase { + public: + QWen2ForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context) {} +}; +TORCH_MODULE(QWen2ForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(qwen2, QWen2ForCausalLM); + +// register the model args +// example config: +// https://huggingface.co/Qwen/Qwen2-7B-Instruct/blob/main/config.json +REGISTER_MODEL_ARGS(qwen2, [&] { + LOAD_ARG_OR(model_type, "model_type", "qwen2"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 152064); + LOAD_ARG_OR(hidden_size, "hidden_size", 3584); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 28); + LOAD_ARG_OR(n_heads, "num_attention_heads", 28); + LOAD_ARG(n_kv_heads, "num_key_value_heads"); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(attention_bias, "attention_bias", true); + // LOAD_ARG_OR(no_bias, "no_bias", true); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + + // For Qwen2/2.5 model < 7B, tie_word_embeddings = true + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(sliding_window, "sliding_window", 4096); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + + SET_ARG(stop_token_ids, 
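// Editorial note on tie_word_embeddings above (see
// LlmForCausalLMImplBase::load_model earlier in this patch): when the flag is
// true, lm_head_ is loaded from the model's "embed_tokens." tensors rather than
// from a separate "lm_head." entry, so checkpoints that omit lm_head.weight
// still load.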
std::unordered_set({args->eos_token_id()})); +}); + +} // namespace xllm diff --git a/xllm/models/llm/npu/qwen3.h b/xllm/models/llm/npu/qwen3.h new file mode 100644 index 000000000..9c65cc2ed --- /dev/null +++ b/xllm/models/llm/npu/qwen3.h @@ -0,0 +1,238 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "core/layers/qwen3_decoder_layer.h" +#include "llm_model_base.h" + +namespace xllm { + +class QWen3DecoderLayerImpl + : public LlmDecoderLayerImplBase { + public: + QWen3DecoderLayerImpl(const ModelContext& context) + : LlmDecoderLayerImplBase(context) {} +}; +TORCH_MODULE(QWen3DecoderLayer); + +class QWen3ModelImpl : public LlmModelImplBase { + public: + QWen3ModelImpl(const ModelContext& context) + : LlmModelImplBase("qwen3", context.get_model_args()) { + // register submodules + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto parallel_args = context.get_parallel_args(); + auto dp_local_tp_size = + parallel_args.world_size() / parallel_args.dp_size(); + dp_rank_ = parallel_args.rank() / dp_local_tp_size; + + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + norm_ = register_module("norm", layer::RmsNorm(context)); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); + atb_pos_emb_ = layer::PosEmbedding(context); + cos_sin_ = get_concat_rotary_embedding(128, + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); + int32_t mask_value = FLAGS_enable_chunked_prefill ? 
-9984 : 1; + // encode_attn_mask_ = + // layer::AttentionMask(options.device(), + // options.dtype()).get_attn_mask(2048, options.device(), + // options.dtype()); + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + + for (int32_t i = 0; i < model_args.n_layers(); i++) { + auto block = QWen3DecoderLayer(context); + layers_.push_back(block); + blocks_->push_back(block); + } + } + + torch::Tensor deepstack_process(torch::Tensor hidden_states, + torch::Tensor visual_pos_masks, + torch::Tensor visual_embeds) { + visual_pos_masks = visual_pos_masks.to(hidden_states.device()); + auto selected = hidden_states.index({visual_pos_masks}); + auto local_this = selected + visual_embeds; + hidden_states.index_put_({visual_pos_masks}, local_this); + return hidden_states; + } + + virtual torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + bool use_deepstack = input_params.deep_stacks.size() > 0; + ModelInputParams& input_params_new = + const_cast(input_params); + std::vector deep_stacks; + + if (tokens.numel() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device()); + positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + } + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens, 0); + } + if (use_deepstack) { + deep_stacks = input_params.deep_stacks; // [num_deepstack, hidden_size] + } + + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + auto freqs_t = x[0].clone(); + for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { + int64_t offset = dim_idx; + int64_t section_len = mrope_section_[dim_idx]; + int64_t length = section_len * 3; + auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_tensor = + torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); + // freqs_t[..., idx] = freqs[dim_idx][..., idx] + auto src = x[dim_idx].index_select(-1, idx_tensor); + freqs_t.index_copy_(-1, idx_tensor, src); + } + return freqs_t; + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } + + torch::Tensor attn_mask; + + torch::Tensor max_of_seq = torch::max(input_params.kv_seq_lens); + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
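// Editorial note on deepstack_process above: visual_pos_masks selects the
// hidden-state rows that belong to visual tokens, the matching deep-stack
// visual embeddings are added to those rows in place via index_put_, and the
// updated tensor is returned; the forward loop consumes one deep-stack level
// per decoder layer while levels remain.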
std::max(max_of_seq.item(), max_seq_len_) + : 128; + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + + if (FLAGS_enable_chunked_prefill) { + int batch_size = input_params.q_seq_lens_vec.size(); + if (batch_size > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(batch_size); + + for (int j = 0; j < batch_size; j++) { + int start = + input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; + int end = input_params.kv_seq_lens_vec[j]; + + auto req_mask_slice = attn_mask.slice(0, start, end); + req_mask_vec.emplace_back(req_mask_slice); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } + + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event{nullptr}; + std::atomic* event_flag{nullptr}; + + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); + } + } + + auto& layer = layers_[i]; + + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params_new, + i, + event, + event_flag); + if (use_deepstack) { + if (deep_stacks.size() > 0 && i < deep_stacks.size()) { + h = deepstack_process( + h, input_params.visual_pos_masks, deep_stacks[i]); + } + } + } + return norm_(h, 0); + } + + private: + torch::Tensor viusal_pos_mask_; +}; +TORCH_MODULE(QWen3Model); + +class QWen3ForCausalLMImpl : public LlmForCausalLMImplBase { + public: + QWen3ForCausalLMImpl(const ModelContext& context) + : LlmForCausalLMImplBase(context) {} +}; +TORCH_MODULE(QWen3ForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(qwen3, QWen3ForCausalLM); + +// register the model args +REGISTER_MODEL_ARGS(qwen3, [&] { + LOAD_ARG_OR(model_type, "model_type", "qwen3"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(vocab_size, "vocab_size", 152064); + LOAD_ARG_OR(hidden_size, "hidden_size", 3584); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 28); + LOAD_ARG_OR(n_heads, "num_attention_heads", 28); + LOAD_ARG(n_kv_heads, "num_key_value_heads"); + // LOAD_ARG_OR(no_bias, "no_bias", true); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + + // For qwen3/2.5 model < 7B, tie_word_embeddings = true + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); +}); + +} // namespace xllm diff --git a/xllm/models/llm/qwen3_embedding.h b/xllm/models/llm/npu/qwen3_embedding.h similarity index 98% rename from xllm/models/llm/qwen3_embedding.h rename to xllm/models/llm/npu/qwen3_embedding.h index 57e1604a6..7ca804c46 100644 --- a/xllm/models/llm/qwen3_embedding.h +++ b/xllm/models/llm/npu/qwen3_embedding.h @@ -3,7 +3,7 @@ #include #include "core/framework/model/embedding_lm.h" -#include "models/llm/embedding_model_base.h" +#include "embedding_model_base.h" 
#include "qwen3.h" namespace xllm { diff --git a/xllm/models/llm/npu/qwen3_moe.h b/xllm/models/llm/npu/qwen3_moe.h new file mode 100644 index 000000000..d24972fe6 --- /dev/null +++ b/xllm/models/llm/npu/qwen3_moe.h @@ -0,0 +1,460 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +// #include +#include "core/framework/model/npu_dp_ep_padding.h" +#include "core/framework/model_context.h" +#include "core/layers/common/layer_utils.h" +#include "core/layers/qwen3_moe_decoder_layer.h" +#include "llm_model_base.h" + +namespace xllm { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { + public: + Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t i) { + // register submodules + decoder_layer_ = register_module("decoder_layer", + layer::Qwen3MoeDecoderLayer(context, i)); + } + + torch::Tensor forward(torch::Tensor x, + torch::Tensor cos_pos, + torch::Tensor sin_pos, + torch::Tensor attn_mask, + KVCache& kv_cache, + const ModelInputParams& input_params, + torch::Tensor expert_array, + aclrtEvent* event = nullptr, + std::atomic* event_flag = nullptr) { + return decoder_layer_(x, + cos_pos, + sin_pos, + attn_mask, + kv_cache, + input_params, + expert_array, + event, + event_flag); + } + + void load_state_dict(const StateDict& state_dict) { + auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts."); + auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj"); + auto fused_down = experts_state_dict.get_tensor("down_proj"); + + bool is_fused = fused_gate_up.defined() && fused_down.defined(); + + if (is_fused) { + torch::Tensor expert_gate_up = fused_gate_up; + torch::Tensor expert_down = fused_down; + + const int num_experts = expert_gate_up.size(0); + + auto chunks = expert_gate_up.chunk(2, /*dim=*/-1); + auto expert_gate = chunks[0].contiguous(); + auto expert_up = chunks[1].contiguous(); + + std::unordered_map out_state_dict; + for (const auto& [name, tensor] : state_dict) { + if (name.find("self_attn.") == 0 || name.find("mlp.gate.") == 0 || + name.find("input_layernorm.") == 0 || + name.find("post_attention_layernorm.") == 0) { + out_state_dict.emplace(name, tensor); + } + } + + for (int i = 0; i < num_experts; ++i) { + auto gate_i = expert_gate[i].transpose(0, 1); + auto up_i = expert_up[i].transpose(0, 1); + auto down_i = expert_down[i].transpose(0, 1); + + const std::string base = "mlp.experts." 
+ std::to_string(i) + "."; + out_state_dict.emplace(base + "gate_proj.weight", gate_i); + out_state_dict.emplace(base + "up_proj.weight", up_i); + out_state_dict.emplace(base + "down_proj.weight", down_i); + } + decoder_layer_->load_state_dict(StateDict(std::move(out_state_dict))); + } else { + decoder_layer_->load_state_dict(state_dict); + } + } + + void verify_loaded_weights(const std::string& prefix) const { + decoder_layer_->verify_loaded_weights(prefix); + } + + void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); } + + private: + layer::Qwen3MoeDecoderLayer decoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen3MoeDecoderLayer); + +torch::Tensor get_qwen3_moe_rotary_embedding( + int64_t dim, + int64_t seq_len, + double rope_theta, + const torch::TensorOptions& options) { + return get_concat_rotary_embedding(dim, seq_len, rope_theta, options); +} + +class Qwen3MoeModelImpl : public torch::nn::Module { + public: + Qwen3MoeModelImpl(const ModelContext& context) + : device_(context.get_tensor_options().device()) { + auto options = context.get_tensor_options(); + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + mrope_section_ = model_args.rope_scaling_mrope_section(); + blocks_ = register_module("layers", torch::nn::ModuleList()); + layers_.reserve(model_args.n_layers()); + // register submodules + device_ = options.device(); + dtype_ = options.dtype().toScalarType(); + num_speculative_tokens_ = model_args.num_speculative_tokens(); + embed_tokens_ = + register_module("embed_tokens", layer::WordEmbedding(context)); + + cos_sin_ = + get_qwen3_moe_rotary_embedding(128, + model_args.max_position_embeddings(), + model_args.rope_theta(), + options); + + atb_pos_emb_ = layer::PosEmbedding(context); + int32_t mask_value = FLAGS_enable_chunked_prefill ? 
-9984 : 1; + attn_mask_ = layer::AttentionMask(options.device(), + options.dtype().toScalarType(), + /*mask_value=*/mask_value); + norm_ = register_module("norm", layer::RmsNorm(context)); + mapping_data_ = parallel_args.mapping_data(); + + for (int32_t i = 0; i < model_args.n_layers(); ++i) { + auto block = Qwen3MoeDecoderLayer(context, i); + layers_.push_back(block); + blocks_->push_back(block); + } + + dp_size_ = parallel_args.dp_size(); + std::vector indices; + dp_local_tp_size_ = parallel_args.world_size() / dp_size_; + dp_rank_ = parallel_args.rank() / dp_local_tp_size_; + rank_ = parallel_args.rank(); + num_experts_per_tok_ = model_args.num_experts_per_tok(); + for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) { + indices.push_back(i); + } + } + + torch::Tensor deepstack_process(torch::Tensor hidden_states, + torch::Tensor visual_pos_masks, + torch::Tensor visual_embeds) { + visual_pos_masks = visual_pos_masks.to(hidden_states.device()); + auto selected = hidden_states.index({visual_pos_masks}); + auto local_this = selected + visual_embeds; + hidden_states.index_put_({visual_pos_masks}, local_this); + return hidden_states; + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + torch::Tensor forward(torch::Tensor tokens, + torch::Tensor positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + if (dp_size_ > 1) { + if (tokens.sizes() == 0) { + tokens = torch::tensor({1}).to(torch::kInt32).to(device_); + positions = torch::tensor({0}).to(torch::kInt32).to(device_); + } + } + auto inputs_embeds = input_params.input_embedding; + torch::Tensor h; + if (inputs_embeds.defined()) { + h = inputs_embeds; + } else { + h = embed_tokens_(tokens, 0); + } + + auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); + auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); + auto cos_pos = target_cos_sin_chunks[0].contiguous(); + auto sin_pos = target_cos_sin_chunks[1].contiguous(); + if (positions.dim() == 2) { // mrope + auto apply = [this](torch::Tensor x) { + // auto sections = mrope_section_; + auto freqs_t = x[0].clone(); + for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { + int64_t offset = dim_idx; // H -> offset=1, W -> offset=2 + int64_t section_len = mrope_section_[dim_idx]; + int64_t length = section_len * 3; + + // indices: [offset, offset+3, offset+6, ..., < length] + auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); + auto idx_tensor = + torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); + // freqs_t[..., idx] = freqs[dim_idx][..., idx] + auto src = x[dim_idx].index_select(-1, idx_tensor); + freqs_t.index_copy_(-1, idx_tensor, src); + } + return freqs_t; + }; + cos_pos = apply(cos_pos.reshape( + {positions.sizes().front(), -1, cos_pos.sizes().back()})); + sin_pos = apply(sin_pos.reshape( + {positions.sizes().front(), -1, sin_pos.sizes().back()})); + } + + torch::Tensor attn_mask; + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
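// Editorial note on the chunked-prefill branch below: the dense causal mask is
// built once for the largest kv length seen so far, and each request j then
// keeps only the rows of its newly appended query tokens, i.e. rows
// [kv_seq_lens_vec[j] - q_seq_lens_vec[j], kv_seq_lens_vec[j]). For example, a
// request with 10 cached tokens and a 3-token chunk contributes rows 7..9, and
// the per-request slices are concatenated along dim 0.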
std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + if (FLAGS_enable_chunked_prefill) { + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + + int batch_size = input_params.q_seq_lens_vec.size(); + if (batch_size > 0) { + std::vector req_mask_vec; + req_mask_vec.reserve(batch_size); + + for (int j = 0; j < batch_size; j++) { + int start = + input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; + int end = input_params.kv_seq_lens_vec[j]; + + auto req_mask_slice = attn_mask.slice(0, start, end); + req_mask_vec.emplace_back(req_mask_slice); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else if (input_params.global_empty_kv_cache) { + attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); + } + auto deep_stacks = input_params.deep_stacks; + int deep_stack_size = deep_stacks.size(); + + int64_t input_length = h.size(0); + torch::Tensor expert_array = torch::arange( + 0, + input_length * num_experts_per_tok_, + torch::TensorOptions().dtype(torch::kInt32).device(tokens.device())); + for (size_t i = 0; i < layers_.size(); i++) { + aclrtEvent* event = nullptr; + std::atomic* event_flag = nullptr; + if (input_params.layer_synchronizer != nullptr) { + event = input_params.layer_synchronizer->get_event(i); + event_flag = input_params.layer_synchronizer->get_event_flag(i); + } + if (input_params.layer_wise_load_synchronizer != nullptr) { + if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { + return torch::Tensor(); + } + } + + auto& layer = layers_[i]; + layer(h, + cos_pos, + sin_pos, + attn_mask, + kv_caches[i], + input_params, + expert_array, + event, + event_flag); + if (deep_stack_size && i < deep_stack_size) { + h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); + } + } + return norm_(h, 0); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + embed_tokens_->load_state_dict( + state_dict.get_dict_with_prefix("embed_tokens.")); + // call each layer's load_state_dict function + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->load_state_dict( + state_dict.get_dict_with_prefix("layers." + std::to_string(i) + ".")); + } + norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); + } + + void verify_loaded_weights(const std::string& prefix) const { + embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + + "."); + } + norm_->verify_loaded_weights(prefix + "norm."); + } + + void merge_loaded_weights() { + embed_tokens_->merge_loaded_weights(); + for (int i = 0; i < layers_.size(); i++) { + layers_[i]->merge_loaded_weights(); + } + norm_->merge_loaded_weights(); + } + + layer::WordEmbedding get_word_embedding() { return embed_tokens_; } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + embed_tokens_ = word_embedding; + } + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return embed_tokens_(input_ids, 0); + } + + private: + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + int32_t max_seq_len_ = 0; + int32_t dp_rank_; + int32_t rank_; + int32_t dp_size_; + int32_t dp_local_tp_size_; + nlohmann::json mapping_data_; + int32_t num_experts_per_tok_; + int32_t num_speculative_tokens_ = 0; + at::Device device_; + torch::Dtype dtype_; + layer::WordEmbedding embed_tokens_{nullptr}; + layer::AttentionMask attn_mask_; + layer::RmsNorm norm_{nullptr}; + torch::Tensor cos_sin_; + layer::PosEmbedding atb_pos_emb_{nullptr}; + std::vector mrope_section_; +}; +TORCH_MODULE(Qwen3MoeModel); + +class Qwen3MoeForCausalLMImpl : public torch::nn::Module { + public: + Qwen3MoeForCausalLMImpl(const ModelContext& context) { + model_ = register_module("model", Qwen3MoeModel(context)); + lm_head_ = register_module("lm_head", layer::LmHead(context)); + } + + // tokens: [num_tokens] + // positions: [num_tokens] token pos in the sequence + // returns: [num_tokens, hidden_size] + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + return model_(tokens, positions, kv_caches, input_params); + } + + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // returns: [num_tokens, vocab_size] + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return lm_head_(hidden_states, seleted_idxes, 0); + } + + torch::Tensor get_input_embeddings(torch::Tensor input_ids) { + return model_->get_input_embeddings(input_ids); + } + + void load_model(std::unique_ptr loader, + std::string prefix = "model." 
/*llm model weight prefix*/) { + for (const auto& state_dict : loader->get_state_dicts()) { + model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); + lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); + } + + model_->verify_loaded_weights(prefix); + lm_head_->verify_loaded_weights("lm_head."); + + model_->merge_loaded_weights(); + lm_head_->merge_loaded_weights(); + } + + virtual void prepare_expert_weight(int32_t layer_id, + const std::vector& expert_ids) { + return; + } + virtual void update_expert_weight(int32_t layer_id) { return; } + + layer::LmHead get_lm_head() { return lm_head_; } + + void set_lm_head(layer::LmHead& head) { lm_head_ = head; } + + layer::WordEmbedding get_word_embedding() { + return model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + model_->set_word_embedding(word_embedding); + } + + private: + Qwen3MoeModel model_{nullptr}; + layer::LmHead lm_head_{nullptr}; +}; +TORCH_MODULE(Qwen3MoeForCausalLM); + +// register the causal model +REGISTER_CAUSAL_MODEL(qwen3_moe, Qwen3MoeForCausalLM); + +// register the model args +// example config: +// https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json +// https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json +REGISTER_MODEL_ARGS(qwen3_moe, [&] { + LOAD_ARG_OR(model_type, "model_type", "qwen3_moe"); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + LOAD_ARG_OR(attention_bias, "attention_bias", false); + LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 151643); + LOAD_ARG_OR(decoder_sparse_step, "decoder_sparse_step", 1); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151645); + LOAD_ARG_OR(head_dim, "head_dim", 128); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "hidden_size", 2048); + LOAD_ARG_OR(initializer_range, "initializer_range", 0.02f); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 6144); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 40960); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 48); + LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 768); + LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true); + LOAD_ARG_OR(n_heads, "num_attention_heads", 32); + LOAD_ARG_OR(num_experts, "num_experts", 128); + LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 48); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 4); + LOAD_ARG_OR(output_router_logits, "output_router_logits", false); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + LOAD_ARG_OR(router_aux_loss_coef, "router_aux_loss_coef", 0.001f); + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(vocab_size, "vocab_size", 151936); + LOAD_ARG_OR(mlp_only_layers, "mlp_only_layers", std::vector()); + + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); +}); +} // namespace xllm diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index af5bfacf4..4d56229a6 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ -49,14 +49,6 @@ class QWen2ModelImpl : public LlmModelImplBase { norm_ = register_module("norm", layer::RmsNorm(context)); embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(context)); -#if defined(USE_NPU) - atb_pos_emb_ = layer::PosEmbedding(context); -#endif - cos_sin_ = 
get_concat_rotary_embedding( - model_args.hidden_size() / model_args.n_heads(), - model_args.max_position_embeddings(), - model_args.rope_theta(), - options); int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; attn_mask_ = layer::AttentionMask(options.device(), options.dtype().toScalarType(), diff --git a/xllm/models/llm/qwen3.h b/xllm/models/llm/qwen3.h index 5e41dd09e..4decdc100 100644 --- a/xllm/models/llm/qwen3.h +++ b/xllm/models/llm/qwen3.h @@ -45,21 +45,6 @@ class QWen3ModelImpl : public LlmModelImplBase { norm_ = register_module("norm", layer::RmsNorm(context)); embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(context)); -#if defined(USE_NPU) - atb_pos_emb_ = layer::PosEmbedding(context); -#endif - cos_sin_ = get_concat_rotary_embedding(128, - model_args.max_position_embeddings(), - model_args.rope_theta(), - options); - int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; - // encode_attn_mask_ = - // layer::AttentionMask(options.device(), - // options.dtype()).get_attn_mask(2048, options.device(), - // options.dtype()); - attn_mask_ = layer::AttentionMask(options.device(), - options.dtype().toScalarType(), - /*mask_value=*/mask_value); for (int32_t i = 0; i < model_args.n_layers(); i++) { auto block = QWen3DecoderLayer(context); @@ -96,130 +81,23 @@ class QWen3ModelImpl : public LlmModelImplBase { if (inputs_embeds.defined()) { h = inputs_embeds; } else { -#if defined(USE_NPU) - h = embed_tokens_(tokens, 0); -#else h = embed_tokens_(tokens); -#endif } - if (use_deepstack) { - deep_stacks = input_params.deep_stacks; // [num_deepstack, hidden_size] - } -#if defined(USE_NPU) - auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); -#else - auto target_cos_sin = cos_sin_.index({positions}); -#endif - auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = target_cos_sin_chunks[0].contiguous(); - auto sin_pos = target_cos_sin_chunks[1].contiguous(); - - if (positions.dim() == 2) { // mrope - auto apply = [this](torch::Tensor x) { - auto freqs_t = x[0].clone(); - for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { - int64_t offset = dim_idx; - int64_t section_len = mrope_section_[dim_idx]; - int64_t length = section_len * 3; - auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); - auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); - auto idx_tensor = - torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); - // freqs_t[..., idx] = freqs[dim_idx][..., idx] - auto src = x[dim_idx].index_select(-1, idx_tensor); - freqs_t.index_copy_(-1, idx_tensor, src); - } - return freqs_t; - }; - cos_pos = apply(cos_pos.reshape( - {positions.sizes().front(), -1, cos_pos.sizes().back()})); - sin_pos = apply(sin_pos.reshape( - {positions.sizes().front(), -1, sin_pos.sizes().back()})); - } - - torch::Tensor attn_mask; - - torch::Tensor max_of_seq = torch::max(input_params.kv_seq_lens); - max_seq_len_ = FLAGS_enable_chunked_prefill - ? 
std::max(max_of_seq.item(), max_seq_len_) - : 128; - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - - if (FLAGS_enable_chunked_prefill) { - int batch_size = input_params.q_seq_lens_vec.size(); - if (batch_size > 0) { - std::vector req_mask_vec; - req_mask_vec.reserve(batch_size); - - for (int j = 0; j < batch_size; j++) { - int start = - input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; - int end = input_params.kv_seq_lens_vec[j]; - - auto req_mask_slice = attn_mask.slice(0, start, end); - req_mask_vec.emplace_back(req_mask_slice); - } - attn_mask = torch::cat(req_mask_vec, 0); - } - } - -#if defined(USE_NPU) - for (size_t i = 0; i < layers_.size(); i++) { - aclrtEvent* event{nullptr}; - std::atomic* event_flag{nullptr}; - - if (input_params.layer_synchronizer != nullptr) { - event = input_params.layer_synchronizer->get_event(i); - event_flag = input_params.layer_synchronizer->get_event_flag(i); - } - if (input_params.layer_wise_load_synchronizer != nullptr) { - if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { - return torch::Tensor(); - } - } - auto& layer = layers_[i]; - - layer(h, - cos_pos, - sin_pos, - attn_mask, - kv_caches[i], - input_params_new, - i, - event, - event_flag); - if (use_deepstack) { - if (deep_stacks.size() > 0 && i < deep_stacks.size()) { - h = deepstack_process( - h, input_params.visual_pos_masks, deep_stacks[i]); - } - } - } - return norm_(h, 0); -#else - layer::update_dummy_run_input(dp_rank_, positions, input_params_new); - bool is_prefill = input_params_new.q_max_seq_len > 1; + auto modified_input_params = input_params; + auto position = positions; + layer::update_dummy_run_input(dp_rank_, position, modified_input_params); + bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = - layer::AttentionMetadata::build(input_params_new, is_prefill); - if (positions.dim() == 2) { - attn_metadata.mrope_cos = std::move(cos_pos); - attn_metadata.mrope_sin = std::move(sin_pos); - } + layer::AttentionMetadata::build(modified_input_params, is_prefill); + torch::Tensor h_ret; for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; - h = layer(h, positions, attn_metadata, kv_caches[i], input_params_new); - if (use_deepstack) { - if (deep_stacks.size() > 0 && i < deep_stacks.size()) { - h = deepstack_process( - h, input_params.visual_pos_masks, deep_stacks[i]); - } - } + h_ret = layer( + h, positions, attn_metadata, kv_caches[i], modified_input_params); } - return norm_(h); -#endif + return norm_(h_ret); } private: diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 8ae069047..ed59f73c0 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -17,10 +17,6 @@ limitations under the License. 
#include -#include -#if defined(USE_NPU) -#include "core/framework/model/npu_dp_ep_padding.h" -#endif #include "core/framework/model_context.h" #include "core/layers/common/layer_utils.h" #include "core/layers/qwen3_moe_decoder_layer.h" @@ -39,27 +35,6 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { layer::Qwen3MoeDecoderLayer(context, i)); } -#if defined(USE_NPU) - torch::Tensor forward(torch::Tensor x, - torch::Tensor cos_pos, - torch::Tensor sin_pos, - torch::Tensor attn_mask, - KVCache& kv_cache, - const ModelInputParams& input_params, - torch::Tensor expert_array, - aclrtEvent* event = nullptr, - std::atomic* event_flag = nullptr) { - return decoder_layer_(x, - cos_pos, - sin_pos, - attn_mask, - kv_cache, - input_params, - expert_array, - event, - event_flag); - } -#else torch::Tensor forward(torch::Tensor& x, torch::Tensor& positions, const layer::AttentionMetadata& attn_metadata, @@ -67,7 +42,7 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { const ModelInputParams& input_params) { return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); } -#endif + void load_state_dict(const StateDict& state_dict) { auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts."); auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj"); @@ -110,27 +85,11 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { } } -#if defined(USE_NPU) - void verify_loaded_weights(const std::string& prefix) const { - decoder_layer_->verify_loaded_weights(prefix); - } - - void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); } -#endif - private: layer::Qwen3MoeDecoderLayer decoder_layer_{nullptr}; }; TORCH_MODULE(Qwen3MoeDecoderLayer); -torch::Tensor get_qwen3_moe_rotary_embedding( - int64_t dim, - int64_t seq_len, - double rope_theta, - const torch::TensorOptions& options) { - return get_concat_rotary_embedding(dim, seq_len, rope_theta, options); -} - class Qwen3MoeModelImpl : public torch::nn::Module { public: Qwen3MoeModelImpl(const ModelContext& context) @@ -148,19 +107,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module { embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(context)); - cos_sin_ = - get_qwen3_moe_rotary_embedding(128, - model_args.max_position_embeddings(), - model_args.rope_theta(), - options); - -#if defined(USE_NPU) - atb_pos_emb_ = layer::PosEmbedding(context); - int32_t mask_value = FLAGS_enable_chunked_prefill ? 
-9984 : 1; - attn_mask_ = layer::AttentionMask(options.device(), - options.dtype().toScalarType(), - /*mask_value=*/mask_value); -#endif + max_seq_len_ = model_args.max_position_embeddings(); norm_ = register_module("norm", layer::RmsNorm(context)); mapping_data_ = parallel_args.mapping_data(); @@ -203,134 +150,19 @@ class Qwen3MoeModelImpl : public torch::nn::Module { positions = torch::tensor({0}).to(torch::kInt32).to(device_); } } - auto inputs_embeds = input_params.input_embedding; - torch::Tensor h; - if (inputs_embeds.defined()) { - h = inputs_embeds; - } else { -#if defined(USE_NPU) - h = embed_tokens_(tokens, 0); -#else - h = embed_tokens_(tokens); -#endif - } -#if defined(USE_NPU) - auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); -#else - auto target_cos_sin = cos_sin_.index({positions}); -#endif - auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); - auto cos_pos = target_cos_sin_chunks[0].contiguous(); - auto sin_pos = target_cos_sin_chunks[1].contiguous(); - if (positions.dim() == 2) { // mrope - auto apply = [this](torch::Tensor x) { - // auto sections = mrope_section_; - auto freqs_t = x[0].clone(); - for (int dim_idx = 1; dim_idx <= 2; ++dim_idx) { - int64_t offset = dim_idx; // H -> offset=1, W -> offset=2 - int64_t section_len = mrope_section_[dim_idx]; - int64_t length = section_len * 3; - - // indices: [offset, offset+3, offset+6, ..., < length] - auto idx_first_half = torch::arange(offset, length, 3, torch::kLong); - auto idx_second_half = torch::arange(offset, length, 3, torch::kLong); - auto idx_tensor = - torch::cat({idx_first_half, idx_second_half}, 0).to(x.device()); - // freqs_t[..., idx] = freqs[dim_idx][..., idx] - auto src = x[dim_idx].index_select(-1, idx_tensor); - freqs_t.index_copy_(-1, idx_tensor, src); - } - return freqs_t; - }; - cos_pos = apply(cos_pos.reshape( - {positions.sizes().front(), -1, cos_pos.sizes().back()})); - sin_pos = apply(sin_pos.reshape( - {positions.sizes().front(), -1, sin_pos.sizes().back()})); - } - - torch::Tensor attn_mask; - max_seq_len_ = FLAGS_enable_chunked_prefill - ? 
std::max(input_params.kv_max_seq_len, max_seq_len_) - : 128; - if (FLAGS_enable_chunked_prefill) { - attn_mask = attn_mask_.get_attn_mask( - max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); - - int batch_size = input_params.q_seq_lens_vec.size(); - if (batch_size > 0) { - std::vector req_mask_vec; - req_mask_vec.reserve(batch_size); - - for (int j = 0; j < batch_size; j++) { - int start = - input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; - int end = input_params.kv_seq_lens_vec[j]; - - auto req_mask_slice = attn_mask.slice(0, start, end); - req_mask_vec.emplace_back(req_mask_slice); - } - attn_mask = torch::cat(req_mask_vec, 0); - } - } else if (input_params.global_empty_kv_cache) { - attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); - } - auto deep_stacks = input_params.deep_stacks; - int deep_stack_size = deep_stacks.size(); -#if defined(USE_NPU) - int64_t input_length = h.size(0); - torch::Tensor expert_array = torch::arange( - 0, - input_length * num_experts_per_tok_, - torch::TensorOptions().dtype(torch::kInt32).device(tokens.device())); - for (size_t i = 0; i < layers_.size(); i++) { - aclrtEvent* event = nullptr; - std::atomic* event_flag = nullptr; - if (input_params.layer_synchronizer != nullptr) { - event = input_params.layer_synchronizer->get_event(i); - event_flag = input_params.layer_synchronizer->get_event_flag(i); - } - if (input_params.layer_wise_load_synchronizer != nullptr) { - if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) { - return torch::Tensor(); - } - } - - auto& layer = layers_[i]; - layer(h, - cos_pos, - sin_pos, - attn_mask, - kv_caches[i], - input_params, - expert_array, - event, - event_flag); - if (deep_stack_size && i < deep_stack_size) { - h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); - } - } - return norm_(h, 0); -#else ModelInputParams modified_input_params = input_params; layer::update_dummy_run_input(dp_rank_, positions, modified_input_params); bool is_prefill = modified_input_params.q_max_seq_len > 1; auto attn_metadata = layer::AttentionMetadata::build(modified_input_params, is_prefill); - if (positions.dim() == 2) { - attn_metadata.mrope_cos = std::move(cos_pos); - attn_metadata.mrope_sin = std::move(sin_pos); - } + torch::Tensor h = embed_tokens_(tokens); for (size_t i = 0; i < layers_.size(); i++) { auto& layer = layers_[i]; h = layer( h, positions, attn_metadata, kv_caches[i], modified_input_params); - if (deep_stack_size && i < deep_stack_size) { - h = deepstack_process(h, input_params.visual_pos_masks, deep_stacks[i]); - } } return norm_(h); -#endif } // load the weight from the checkpoint @@ -345,38 +177,13 @@ class Qwen3MoeModelImpl : public torch::nn::Module { norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } -#if defined(USE_NPU) - void verify_loaded_weights(const std::string& prefix) const { - embed_tokens_->verify_loaded_weights(prefix + "embed_tokens."); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->verify_loaded_weights(prefix + "layers." 
+ std::to_string(i) + - "."); - } - norm_->verify_loaded_weights(prefix + "norm."); - } - - void merge_loaded_weights() { - embed_tokens_->merge_loaded_weights(); - for (int i = 0; i < layers_.size(); i++) { - layers_[i]->merge_loaded_weights(); - } - norm_->merge_loaded_weights(); - } -#endif - layer::WordEmbedding get_word_embedding() { return embed_tokens_; } void set_word_embedding(layer::WordEmbedding& word_embedding) { embed_tokens_ = word_embedding; } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { -#if defined(USE_NPU) - return embed_tokens_(input_ids, 0); -#elif defined(USE_MLU) return embed_tokens_(input_ids); -#else - LOG(FATAL) << "Backend not supported: enable USE_NPU or USE_MLU."; -#endif } private: @@ -395,10 +202,6 @@ class Qwen3MoeModelImpl : public torch::nn::Module { layer::WordEmbedding embed_tokens_{nullptr}; layer::AttentionMask attn_mask_; layer::RmsNorm norm_{nullptr}; - torch::Tensor cos_sin_; -#if defined(USE_NPU) - layer::PosEmbedding atb_pos_emb_{nullptr}; -#endif std::vector mrope_section_; }; TORCH_MODULE(Qwen3MoeModel); @@ -425,16 +228,12 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { // returns: [num_tokens, vocab_size] torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) { -#if defined(USE_NPU) - return lm_head_(hidden_states, seleted_idxes, 0); -#else // select tokens if provided auto h = hidden_states; if (seleted_idxes.defined()) { h = h.index_select(/*dim=*/0, seleted_idxes); } return lm_head_(h); -#endif } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { @@ -447,15 +246,6 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module { model_->load_state_dict(state_dict->get_dict_with_prefix(prefix)); lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } - -#if defined(USE_NPU) - // verify - model_->verify_loaded_weights(prefix); - lm_head_->verify_loaded_weights("lm_head."); - - model_->merge_loaded_weights(); - lm_head_->merge_loaded_weights(); -#endif } virtual void prepare_expert_weight(int32_t layer_id, diff --git a/xllm/models/models.h b/xllm/models/models.h index da0338353..b0d76860d 100644 --- a/xllm/models/models.h +++ b/xllm/models/models.h @@ -15,6 +15,40 @@ limitations under the License. 
#pragma once +#if defined(USE_NPU) +#include "dit/pipeline_flux.h" // IWYU pragma: keep +#include "dit/pipeline_flux_control.h" // IWYU pragma: keep +#include "dit/pipeline_flux_fill.h" // IWYU pragma: keep +#include "llm/npu/deepseek_v2.h" // IWYU pragma: keep +#include "llm/npu/deepseek_v2_mtp.h" // IWYU pragma: keep +#include "llm/npu/deepseek_v3.h" // IWYU pragma: keep +#include "llm/npu/embedding_model_base.h" // IWYU pragma: keep +#include "llm/npu/glm4_moe.h" // IWYU pragma: keep +#include "llm/npu/glm4_moe_mtp.h" // IWYU pragma: keep +#include "llm/npu/kimi_k2.h" // IWYU pragma: keep +#include "llm/npu/llama.h" // IWYU pragma: keep +#include "llm/npu/llama3.h" // IWYU pragma: keep +#include "llm/npu/llm_model_base.h" // IWYU pragma: keep +#include "llm/npu/qwen2.h" // IWYU pragma: keep +#include "llm/npu/qwen3.h" // IWYU pragma: keep +#include "llm/npu/qwen3_embedding.h" // IWYU pragma: keep +#include "llm/npu/qwen3_moe.h" // IWYU pragma: keep +#include "vlm/npu/minicpmv.h" // IWYU pragma: keep +#include "vlm/npu/qwen2_5_vl.h" // IWYU pragma: keep +#include "vlm/npu/qwen3_vl.h" // IWYU pragma: keep +#include "vlm/npu/qwen3_vl_moe.h" // IWYU pragma: keep +#elif defined(USE_MLU) +#include "llm/deepseek_v2.h" // IWYU pragma: keep +#include "llm/deepseek_v3.h" // IWYU pragma: keep +#include "llm/deepseek_v32.h" // IWYU pragma: keep +#include "llm/llm_model_base.h" // IWYU pragma: keep +#include "llm/qwen2.h" // IWYU pragma: keep +#include "llm/qwen3.h" // IWYU pragma: keep +#include "llm/qwen3_moe.h" // IWYU pragma: keep +#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep +#include "vlm/qwen3_vl.h" // IWYU pragma: keep +#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep +#elif defined(USE_CUDA) #include "llm/llm_model_base.h" // IWYU pragma: keep #include "llm/qwen2.h" // IWYU pragma: keep #include "llm/qwen3.h" // IWYU pragma: keep @@ -22,23 +56,4 @@ limitations under the License. #include "vlm/qwen2_5_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl.h" // IWYU pragma: keep #include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep - -#if defined(USE_NPU) -#include "dit/pipeline_flux.h" // IWYU pragma: keep -#include "dit/pipeline_flux_control.h" // IWYU pragma: keep -#include "dit/pipeline_flux_fill.h" // IWYU pragma: keep -#include "llm/deepseek_v2.h" // IWYU pragma: keep -#include "llm/deepseek_v2_mtp.h" // IWYU pragma: keep -#include "llm/deepseek_v3.h" // IWYU pragma: keep -#include "llm/glm4_moe.h" // IWYU pragma: keep -#include "llm/glm4_moe_mtp.h" // IWYU pragma: keep -#include "llm/kimi_k2.h" // IWYU pragma: keep -#include "llm/llama.h" // IWYU pragma: keep -#include "llm/llama3.h" // IWYU pragma: keep -#include "llm/qwen3_embedding.h" // IWYU pragma: keep -#include "vlm/minicpmv.h" // IWYU pragma: keep -#elif defined(USE_MLU) -#include "llm/mlu/deepseek_v2.h" // IWYU pragma: keep -#include "llm/mlu/deepseek_v3.h" // IWYU pragma: keep -#include "llm/mlu/deepseek_v32.h" // IWYU pragma: keep #endif diff --git a/xllm/models/vlm/minicpmv.h b/xllm/models/vlm/npu/minicpmv.h similarity index 100% rename from xllm/models/vlm/minicpmv.h rename to xllm/models/vlm/npu/minicpmv.h diff --git a/xllm/models/vlm/npu/qwen2_5_vl.h b/xllm/models/vlm/npu/qwen2_5_vl.h new file mode 100644 index 000000000..8c4d609f3 --- /dev/null +++ b/xllm/models/vlm/npu/qwen2_5_vl.h @@ -0,0 +1,836 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/layers/lm_head.h" +#include "core/layers/qwen2_decoder_layer.h" +#include "core/layers/qwen2dot5_vision_decode_layer.h" +#include "core/layers/rms_norm.h" +#include "models/llm/npu/qwen2.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +#define PrintTensor(tensor) print_tensor(tensor, #tensor, 10, true, false); + +class Qwen2_5_VLInputProcessor : public InputProcessor { + enum class TokenType { + INVALID, + IMAGE, + VIDEO, + }; + + public: + Qwen2_5_VLInputProcessor(const ModelArgs& args) { + merge_size_ = args.mm_image_merge_size(); + } + + void process(std::string& prompt, const MMData& mm_data) override { + torch::Tensor image_grid_thw; + if (auto res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + torch::Tensor video_grid_thw; + if (auto res = mm_data.get("video_grid_thw")) + video_grid_thw = res.value(); + + if (!image_grid_thw.defined() && !video_grid_thw.defined()) return; + + auto merge_length = merge_size_ * merge_size_; + int total_image_token = 0; + if (image_grid_thw.defined()) { + auto count = image_grid_thw.sizes()[0]; + for (int idx = 0; idx < count; ++idx) + total_image_token += + image_grid_thw[idx].prod().item() / merge_length; + } + + int total_video_token = 0; + if (video_grid_thw.defined()) { + auto count = video_grid_thw.sizes()[0]; + for (int idx = 0; idx < count; ++idx) + total_video_token += + video_grid_thw[idx].prod().item() / merge_length; + } + + size_t total_token_len = total_image_token * image_token_.size() + + total_video_token * video_token_.size(); + std::string data; + data.reserve(prompt.size() + total_token_len); + + int image_index = 0; + int video_index = 0; + + const torch::Tensor* grid_thw = nullptr; + const std::string* token = nullptr; + int* index = 0; + + size_t begin = 0; + auto pair = find_vision_token(prompt, begin); + + while (pair.second != std::string::npos) { + data.append(prompt, begin, pair.second - begin); + + if (pair.first == TokenType::IMAGE) { + grid_thw = &image_grid_thw; + token = &image_token_; + index = &image_index; + } else if (pair.first == TokenType::VIDEO) { + grid_thw = &video_grid_thw; + token = &video_token_; + index = &video_index; + } else { + assert(false); + } + + auto token_num = (*grid_thw)[(*index)].prod().item() / merge_length; + while (token_num--) data.append(*token); + + ++(*index); + begin = pair.second + token->size(); + pair = find_vision_token(prompt, begin); + } + + if (begin < prompt.size()) data.append(prompt, begin, std::string::npos); + + prompt = std::move(data); + } + + private: + std::pair find_vision_token(const std::string& prompt, + size_t begin) { + auto img_pos = prompt.find(image_token_, begin); + auto vid_pos = prompt.find(video_token_, begin); + + if 
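// Editorial example for process() above (merge_size_ = 2 is the usual
// Qwen2.5-VL spatial merge size and is an assumption here): an image with
// image_grid_thw = {1, 28, 28} contributes 1 * 28 * 28 / (2 * 2) = 196 visual
// tokens, so its single "<|image_pad|>" placeholder in the prompt string is
// replaced by 196 copies; "<|video_pad|>" placeholders are expanded the same
// way from video_grid_thw.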
(img_pos == std::string::npos && vid_pos == std::string::npos) + return {TokenType::INVALID, std::string::npos}; + else if (vid_pos == std::string::npos) + return {TokenType::IMAGE, img_pos}; + else if (img_pos == std::string::npos) + return {TokenType::VIDEO, vid_pos}; + else + return img_pos < vid_pos ? std::make_pair(TokenType::IMAGE, img_pos) + : std::make_pair(TokenType::VIDEO, vid_pos); + } + + private: + const std::string image_token_ = "<|image_pad|>"; + const std::string video_token_ = "<|video_pad|>"; + + int merge_size_ = 0; +}; + +class Qwen2_5_VisionBlockImpl : public torch::nn::Module { + public: + Qwen2_5_VisionBlockImpl(const ModelContext& context) { + // register submodules + encoder_layer_ = register_module( + "encoder_layer", layer::Qwen2dot5VisionEncoderLayer(context)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + return encoder_layer_(x, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params, + node_id); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + encoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + encoder_layer_->verify_loaded_weights(); + } + void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); } + + private: + layer::Qwen2dot5VisionEncoderLayer encoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen2_5_VisionBlock); + +class Qwen2_5_VisionPatchEmbedImpl : public torch::nn::Module { + public: + Qwen2_5_VisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(false))); + + proj_->weight.set_data(proj_->weight.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + } + + private: + bool proj_weight_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; +TORCH_MODULE(Qwen2_5_VisionPatchEmbed); + +class Qwen2_5_VisionRotaryEmbeddingImpl : public torch::nn::Module { + public: + Qwen2_5_VisionRotaryEmbeddingImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + dim_ = model_args.mm_head_dim() / 2; + theta_ = 10000.0; + + auto opts = options.dtype(torch::kFloat32); + auto inv_freq = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, opts) / dim_); + inv_freq_ = register_buffer("inv_freq", inv_freq); + } + + void update_freqs_cache(int64_t seqlen) { + if (seqlen <= seq_len_cached_) 
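// Cache hit: the precomputed table already covers the requested length.
// (Editorial note: on a miss, the code below doubles the requested length
// before rebuilding, recomputes inv_freq as 1 / theta^(2k / dim), and stores
// freqs_cached_ = outer(arange(seq_len_cached_), inv_freq) so later shorter
// requests are served by slicing in forward().)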
return; + + seqlen *= 2; + seq_len_cached_ = seqlen; + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(inv_freq_.device()); + inv_freq_ = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, options) / dim_); + auto seq = torch::arange(seqlen, options); + freqs_cached_ = torch::outer(seq, inv_freq_); + } + + torch::Tensor forward(int seqlen) { + update_freqs_cache(seqlen); + return freqs_cached_.slice(0, 0, seqlen); + } + + private: + int dim_ = 0; + double theta_ = 0.0; + + int64_t seq_len_cached_ = 0; + torch::Tensor inv_freq_; + torch::Tensor freqs_cached_; +}; +TORCH_MODULE(Qwen2_5_VisionRotaryEmbedding); + +class Qwen2_5_VisionPatchMergerImpl : public torch::nn::Module { + public: + Qwen2_5_VisionPatchMergerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto quant_args = context.get_quant_args(); + auto parallel_args = context.get_parallel_args(); + + int64_t d_model = model_args.mm_projection_dim(); // out_hidden_size + int context_dim = model_args.mm_hidden_size(); + int spatial_merge_size = model_args.mm_spatial_merge_size(); + + hidden_size_ = + context_dim * static_cast(std::pow(spatial_merge_size, 2)); + + ln_q_ = register_module("ln_q", layer::RmsNorm(context)); + + auto cpl = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, hidden_size_).bias(true)); + cpl->weight.set_data(cpl->weight.to(options)); + cpl->bias.set_data(cpl->bias.to(options)); + auto act = torch::nn::GELU(); + auto rpl = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, d_model).bias(true)); + rpl->weight.set_data(rpl->weight.to(options)); + rpl->bias.set_data(rpl->bias.to(options)); + mlp_ = register_module("mlp", torch::nn::Sequential(cpl, act, rpl)); + layers_ = std::make_tuple(cpl, act, rpl); + } + + torch::Tensor forward(torch::Tensor x) { + x = ln_q_(x, 0); + x = x.view({-1, hidden_size_}); + return mlp_->forward(x); + } + + void load_state_dict(const StateDict& state_dict) { + ln_q_->load_state_dict(state_dict.get_dict_with_prefix("ln_q.")); + + const auto& cpl_dict = state_dict.get_dict_with_prefix("mlp.0."); + const auto& cpl_weight = cpl_dict.get_tensor("weight"); + if (cpl_weight.defined()) { + CHECK_EQ(std::get<0>(layers_)->weight.sizes(), cpl_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<0>(layers_)->weight.data().copy_(cpl_weight); + is_cpl_weight_loaded = true; + } + const auto cpl_bias = cpl_dict.get_tensor("bias"); + if (cpl_bias.defined()) { + CHECK_EQ(std::get<0>(layers_)->bias.sizes(), cpl_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<0>(layers_)->bias.data().copy_(cpl_bias); + is_cpl_bias_loaded = true; + } + + const auto& rpl_dict = state_dict.get_dict_with_prefix("mlp.2."); + const auto& rpl_weight = rpl_dict.get_tensor("weight"); + if (rpl_weight.defined()) { + CHECK_EQ(std::get<2>(layers_)->weight.sizes(), rpl_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<2>(layers_)->weight.data().copy_(rpl_weight); + is_rpl_weight_loaded = true; + } + const auto rpl_bias = rpl_dict.get_tensor("bias"); + if (rpl_bias.defined()) { + CHECK_EQ(std::get<2>(layers_)->bias.sizes(), rpl_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<2>(layers_)->bias.data().copy_(rpl_bias); + is_rpl_bias_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + ln_q_->verify_loaded_weights(prefix + "ln_q."); + CHECK(is_cpl_weight_loaded) + << "weight is not loaded for " << 
prefix + "mlp.0" + ".weight"; + CHECK(is_cpl_bias_loaded) + << "bias is not loaded for " << prefix + "mlp.0" + ".bias"; + CHECK(is_rpl_weight_loaded) + << "weight is not loaded for " << prefix + "mlp.2" + ".weight"; + CHECK(is_rpl_bias_loaded) + << "bias is not loaded for " << prefix + "mlp.2" + ".bias"; + } + + void merge_loaded_weights() { ln_q_->merge_loaded_weights(); } + + private: + int64_t hidden_size_; + + layer::RmsNorm ln_q_{nullptr}; + torch::nn::Sequential mlp_{nullptr}; + std::tuple layers_ = { + nullptr, + nullptr, + nullptr}; + bool is_cpl_weight_loaded = false; + bool is_cpl_bias_loaded = false; + bool is_rpl_weight_loaded = false; + bool is_rpl_bias_loaded = false; +}; +TORCH_MODULE(Qwen2_5_VisionPatchMerger); + +class Qwen2_5_VisionTransformerImpl : public torch::nn::Module { + public: + Qwen2_5_VisionTransformerImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + hidden_size_ = model_args.mm_hidden_size(); + num_heads_ = model_args.mm_num_attention_heads(); + + window_size_ = model_args.mm_window_size(); + patch_size_ = model_args.mm_patch_size(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + const auto& block_indexes = model_args.mm_fullatt_block_indexes(); + fullatt_block_indexes_.insert(block_indexes.begin(), block_indexes.end()); + spatial_merge_unit_ = static_cast(std::pow(spatial_merge_size_, 2)); + + patch_embed_ = + register_module("patch_embed", Qwen2_5_VisionPatchEmbed(context)); + rotary_pos_emb_ = register_module("rotary_pos_emb", + Qwen2_5_VisionRotaryEmbedding(context)); + blocks_ = register_module("blocks", torch::nn::ModuleList()); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = Qwen2_5_VisionBlock(context); + blocks_->push_back(block); + layers_.push_back(block); + } + merger_ = register_module("merger", Qwen2_5_VisionPatchMerger(context)); + } + + torch::Tensor rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + + auto grid_thw_cpu = grid_thw.cpu(); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + pos_ids_vec.push_back( + torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = + grid_thw + .index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}) + .max(); + + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return rotary_pos_emb; + } + + torch::Tensor get_window_index(torch::Tensor grid_thw, + std::vector& cu_window_seqlens) { + auto count = grid_thw.sizes()[0]; + std::vector window_index; + window_index.reserve(count); + cu_window_seqlens.reserve(count 
* 128); + cu_window_seqlens.emplace_back(0); + + int window_index_id = 0; + int vit_merger_window_size = + window_size_ / spatial_merge_size_ / patch_size_; + + auto grid_thw_cpu = grid_thw.cpu(); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + for (int idx = 0; idx < count; ++idx) { + auto grid_t = grid_thw_cpu[idx][0].item(); + auto grid_h = grid_thw_cpu[idx][1].item(); + auto grid_w = grid_thw_cpu[idx][2].item(); + + auto llm_grid_h = grid_h / spatial_merge_size_; + auto llm_grid_w = grid_w / spatial_merge_size_; + + auto index = torch::arange(grid_t * llm_grid_h * llm_grid_w, options) + .reshape({grid_t, llm_grid_h, llm_grid_w}); + auto pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size; + auto pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size; + + auto num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size; + auto num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size; + + namespace F = torch::nn::functional; + auto index_padded = F::pad(index, + F::PadFuncOptions({0, pad_w, 0, pad_h}) + .mode(torch::kConstant) + .value(-100)); + index_padded = index_padded.reshape({grid_t, + num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size}); + + index_padded = index_padded.permute({0, 1, 3, 2, 4}) + .reshape({grid_t, + num_windows_h * num_windows_w, + vit_merger_window_size, + vit_merger_window_size}); + + auto index_padded_ne = torch::ne(index_padded, -100); + auto seqlens = index_padded_ne.sum({2, 3}).reshape({-1}); + index_padded = index_padded.reshape({-1}); + auto index_new = + index_padded.masked_select(index_padded_ne.reshape({-1})); + + window_index.push_back(index_new + window_index_id); + auto cu_seqlens_tmp = + (seqlens.cumsum(0, torch::kInt32) * spatial_merge_unit_ + + cu_window_seqlens.back()) + .cpu(); + cu_window_seqlens.insert( + cu_window_seqlens.end(), + cu_seqlens_tmp.data_ptr(), + cu_seqlens_tmp.data_ptr() + cu_seqlens_tmp.numel()); + window_index_id += grid_t * llm_grid_h * llm_grid_w; + } + + return torch::cat(window_index, 0); + } + + torch::Tensor forward(torch::Tensor hidden_states, + torch::Tensor grid_thw, // [batch,thw] + const ModelInputParams& input_params) { + // patchify + // hidden_states = x.to(device=self.device, dtype=self.dtype); + hidden_states = patch_embed_(hidden_states); + // compute position embedding + auto rotary_pos_emb = rot_pos_emb(grid_thw); + + // windows attention + std::vector cu_window_seqlens_vec; + auto window_index = get_window_index(grid_thw, cu_window_seqlens_vec); + torch::TensorOptions options = torch::TensorOptions() + .dtype(torch::kInt32) + .device(hidden_states.device()); + auto cu_window_seqlens = torch::tensor(cu_window_seqlens_vec, options); + cu_window_seqlens = + std::get<0>(torch::unique_consecutive(cu_window_seqlens)); + auto seq_len = hidden_states.sizes()[0]; + hidden_states = hidden_states.reshape( + {seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + hidden_states = hidden_states.index( + {window_index, torch::indexing::Slice(), torch::indexing::Slice()}); + hidden_states = hidden_states.reshape({seq_len, -1}); + + rotary_pos_emb = rotary_pos_emb.reshape( + {seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + rotary_pos_emb = rotary_pos_emb.index( + {window_index, torch::indexing::Slice(), torch::indexing::Slice()}); + rotary_pos_emb = rotary_pos_emb.reshape({seq_len, -1}); + + // compute cu_seqlens + auto cu_seqlens = torch::repeat_interleave( + 
grid_thw.index({torch::indexing::Slice(), 1}) *
+                              grid_thw.index({torch::indexing::Slice(), 2}),
+                          grid_thw.index({torch::indexing::Slice(), 0}))
+                          .cumsum(0, torch::kInt32);
+    namespace F = torch::nn::functional;
+    cu_seqlens = F::pad(
+        cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
+
+    m_cos = rotary_pos_emb.cos().type_as(hidden_states);
+    m_sin = rotary_pos_emb.sin().type_as(hidden_states);
+
+    // transformers
+    cu_seqlens = torch::diff(cu_seqlens);
+    cu_window_seqlens = torch::diff(cu_window_seqlens);
+
+    m_cos = torch::nn::functional::pad(
+        m_cos, torch::nn::functional::PadFuncOptions({0, 24}));
+    m_sin = torch::nn::functional::pad(
+        m_sin, torch::nn::functional::PadFuncOptions({0, 24}));
+
+    m_cos = m_cos.repeat({1, 2});
+    m_sin = m_sin.repeat({1, 2});
+    ModelInputParams& input_params_new =
+        const_cast<ModelInputParams&>(input_params);
+    torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
+    torch::Tensor cu_window_seqlens_cpu = cu_window_seqlens.cpu();
+    std::vector<int32_t> cu_seqlens_vec(
+        cu_seqlens_cpu.data_ptr<int32_t>(),  // full seqlen vec
+        cu_seqlens_cpu.data_ptr<int32_t>() + cu_seqlens_cpu.numel());
+    std::vector<int32_t> cu_w_seqlens_vec(
+        cu_window_seqlens_cpu.data_ptr<int32_t>(),  // window seqlen vec
+        cu_window_seqlens_cpu.data_ptr<int32_t>() +
+            cu_window_seqlens_cpu.numel());
+    for (int idx = 0; idx < blocks_->size(); ++idx) {
+      torch::Tensor cu_seqlens_now;
+      std::vector<int32_t> cu_seqlens_now_vec;
+      if (fullatt_block_indexes_.find(idx) != fullatt_block_indexes_.end()) {
+        cu_seqlens_now = cu_seqlens;
+        cu_seqlens_now_vec = cu_seqlens_vec;
+      } else {
+        cu_seqlens_now = cu_window_seqlens;
+        cu_seqlens_now_vec = cu_w_seqlens_vec;
+      }
+      hidden_states = layers_[idx](hidden_states,
+                                   m_cos,
+                                   m_sin,
+                                   cu_seqlens_now,
+                                   cu_seqlens_now_vec,
+                                   input_params_new,
+                                   idx);
+    }
+    // adapter
+    hidden_states = merger_(hidden_states);
+
+    auto reverse_indices = torch::argsort(window_index);
+    hidden_states =
+        hidden_states.index({reverse_indices, torch::indexing::Slice()});
+    return hidden_states;
+  }
+
+  void load_state_dict(const StateDict& state_dict) {
+    patch_embed_->load_state_dict(
+        state_dict.get_dict_with_prefix("patch_embed."));
+    for (int idx = 0; idx < blocks_->size(); ++idx) {
+      layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix(
+          "blocks." + std::to_string(idx) + "."));
+    }
+
+    merger_->load_state_dict(state_dict.get_dict_with_prefix("merger."));
+  }
+
+  void verify_loaded_weights(const std::string& prefix) const {
+    patch_embed_->verify_loaded_weights(prefix + "patch_embed.");
+    for (int idx = 0; idx < blocks_->size(); ++idx) {
+      layers_[idx]->verify_loaded_weights(prefix + "blocks." 
+ + std::to_string(idx) + "."); + } + merger_->verify_loaded_weights(prefix + "merger."); + } + + void merge_loaded_weights() { + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->merge_loaded_weights(); + } + merger_->merge_loaded_weights(); + } + + private: + int hidden_size_ = 0; + int num_heads_ = 0; + int window_size_ = 0; + int patch_size_ = 0; + int spatial_merge_size_ = 0; + std::set fullatt_block_indexes_; + int spatial_merge_unit_ = 0; + + Qwen2_5_VisionPatchEmbed patch_embed_{nullptr}; + Qwen2_5_VisionRotaryEmbedding rotary_pos_emb_{nullptr}; + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + Qwen2_5_VisionPatchMerger merger_{nullptr}; + + torch::Tensor m_cos; + torch::Tensor m_sin; + int device_id = 0; +}; +TORCH_MODULE(Qwen2_5_VisionTransformer); + +struct Qwen2_5_VLImageInputs { + torch::Tensor pixel_values; + torch::Tensor image_grid_thw; +}; + +struct Qwen2_5_VLVideoInputs { + torch::Tensor pixel_values_videos; + torch::Tensor video_grid_thw; + torch::Tensor second_per_grid_ts; +}; + +class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module { + public: + Qwen2_5_VLForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Qwen2_5_VisionTransformer(context)); + + language_model_ = + register_module("language_model", QWen2ForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (image_input) { + // visual + auto image_embeds = visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); + inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Qwen2_5_VLImageInputs{pixel_values, image_grid_thw}; + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict(state_dict->get_dict_with_prefix("visual.")); + } + + visual_->verify_loaded_weights("visual."); + visual_->merge_loaded_weights(); + + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader)); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } 
+ void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + + Qwen2_5_VisionTransformer visual_{nullptr}; + QWen2ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen2_5_VLForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(qwen2_5_vl, Qwen2_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(qwen2_5_vl, Qwen2_5_VLForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(qwen2_5_vl, Qwen2VLImageProcessor); + +REGISTER_MODEL_ARGS(qwen2_5_vl, [&] { + // text config + // LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0); + LOAD_ARG_OR(bos_token_id, "bos_token_id", 151643); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151645); + LOAD_ARG_OR(vision_start_token_id, "vision_start_token_id", 151652); + LOAD_ARG_OR(vision_end_token_id, "vision_end_token_id", 151653); + LOAD_ARG_OR(vision_token_id, "vision_token_id", 151654); + LOAD_ARG_OR(image_token_id, "image_token_id", 151655); + LOAD_ARG_OR(video_token_id, "video_token_id", 151656); + LOAD_ARG_OR(hidden_act, "hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "hidden_size", 3584); + // LOAD_ARG_OR(initializer_range, "initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 128000); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + LOAD_ARG_OR(model_type, "model_type", "qwen2_5_vl"); + LOAD_ARG_OR(n_heads, "num_attention_heads", 28); + LOAD_ARG_OR(n_layers, "num_hidden_layers", 28); + LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 4); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-06); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + LOAD_ARG_OR(sliding_window, "sliding_window", 32768); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(dtype, "torch_dtype", ""); + // LOAD_ARG_OR(transformers_version, "transformers_version", "4.41.2"); + // LOAD_ARG_OR(use_cache, "use_cache", true); + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + + // vision_config + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 32); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "silu"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1280); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 3420); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_chans", 3); + LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 3584); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 14); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_spatial_patch_size, "vision_config.spatial_patch_size", 14); + LOAD_ARG_OR(mm_window_size, "vision_config.window_size", 112); + LOAD_ARG_OR(mm_fullatt_block_indexes, + "vision_config.fullatt_block_indexes", + std::vector({7, 15, 23, 31})); + LOAD_ARG_OR(mm_tokens_per_second, "vision_config.tokens_per_second", 2); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / 
args->mm_num_attention_heads(); + }); + + LOAD_ARG_OR( + rope_scaling_rope_type, "vision_config.rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, "rope_scaling.mrope_section"); + LOAD_ARG_OR(vocab_size, "vocab_size", 152064); +}); +} // namespace xllm diff --git a/xllm/models/vlm/npu/qwen3_vl.h b/xllm/models/vlm/npu/qwen3_vl.h new file mode 100644 index 000000000..dff1ae9eb --- /dev/null +++ b/xllm/models/vlm/npu/qwen3_vl.h @@ -0,0 +1,798 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/layers/lm_head.h" +#include "core/layers/qwen3_vision_encode_layer.h" +#include "core/layers/rms_norm.h" +#include "models/llm/npu/qwen3.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "qwen2_5_vl.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +#define PrintTensor(tensor) print_tensor(tensor, #tensor, 10, true, false); + +class Qwen3_VisionPatchEmbedImpl : public torch::nn::Module { + public: + Qwen3_VisionPatchEmbedImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + auto in_features = model_args.mm_num_channels() * + model_args.mm_temporal_patch_size() * + model_args.mm_patch_size() * model_args.mm_patch_size(); + + auto out_features = model_args.mm_hidden_size(); + + proj_ = register_module( + "proj", + torch::nn::Linear( + torch::nn::LinearOptions(in_features, out_features).bias(true))); + + proj_->weight.set_data(proj_->weight.to(options)); + proj_->bias.set_data(proj_->bias.to(options)); + } + + torch::Tensor forward(torch::Tensor x) { return proj_(x); } + + void load_state_dict(const StateDict& state_dict) { + auto weight = state_dict.get_tensor("proj.weight"); + if (weight.defined()) { + weight = weight.reshape({weight.size(0), -1}); + DCHECK_EQ(proj_->weight.sizes(), weight.sizes()) + << "proj weight size mismatch for " << name(); + proj_->weight.data().copy_(weight); + proj_weight_loaded_ = true; + } + auto bias = state_dict.get_tensor("proj.bias"); + if (bias.defined()) { + bias = bias.reshape({bias.size(0)}); + DCHECK_EQ(proj_->bias.sizes(), bias.sizes()) + << "proj bias size mismatch for " << name(); + proj_->bias.data().copy_(bias); + proj_bias_loaded_ = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(proj_weight_loaded_) + << "weight is not loaded for " << prefix + "proj.weight"; + CHECK(proj_bias_loaded_) + << "bias is not loaded for " << prefix + "proj.bias"; + } + + private: + bool proj_weight_loaded_ = false; + bool proj_bias_loaded_ = false; + torch::nn::Linear proj_{nullptr}; +}; 
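+// Descriptive note: the patch embedding above flattens each
+// temporal_patch_size x patch_size x patch_size patch across all input
+// channels and projects it to mm_hidden_size through a single biased Linear;
+// with the vision_config defaults registered later in this file (in_channels
+// 3, temporal_patch_size 2, patch_size 16, hidden_size 1152) that is a
+// 3 * 2 * 16 * 16 = 1536 -> 1152 projection per flattened patch.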
+TORCH_MODULE(Qwen3_VisionPatchEmbed); + +class Qwen3_VisionBlockImpl : public torch::nn::Module { + public: + Qwen3_VisionBlockImpl(const ModelContext& context) { + // register submodules + encoder_layer_ = register_module("encoder_layer", + layer::Qwen3VisionEncoderLayer(context)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& m_cos_pos, + torch::Tensor& m_sin_pos, + torch::Tensor& cu_seq_len, + std::vector& cu_seq_len_vec, + ModelInputParams& input_params, + int node_id) { + return encoder_layer_(x, + m_cos_pos, + m_sin_pos, + cu_seq_len, + cu_seq_len_vec, + input_params, + node_id); + } + + // load the weight from the checkpoint + void load_state_dict(const StateDict& state_dict) { + // call each submodule's load_state_dict function + encoder_layer_->load_state_dict(state_dict); + } + + void verify_loaded_weights(const std::string& prefix) const { + encoder_layer_->verify_loaded_weights(); + } + void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); } + + private: + layer::Qwen3VisionEncoderLayer encoder_layer_{nullptr}; +}; +TORCH_MODULE(Qwen3_VisionBlock); + +class Qwen3_VisionRotaryEmbeddingImpl : public torch::nn::Module { + public: + Qwen3_VisionRotaryEmbeddingImpl(const ModelContext& context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + dim_ = model_args.mm_head_dim() / 2; + theta_ = 10000.0; + + auto opts = options.dtype(torch::kFloat32); + auto inv_freq = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, opts) / dim_); + inv_freq_ = register_buffer("inv_freq", inv_freq); + } + + void update_freqs_cache(int64_t seqlen) { + if (seqlen <= seq_len_cached_) return; + + seqlen *= 2; + seq_len_cached_ = seqlen; + + auto options = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(inv_freq_.device()); + inv_freq_ = + 1.0 / torch::pow(theta_, torch::arange(0, dim_, 2, options) / dim_); + auto seq = torch::arange(seqlen, options); + freqs_cached_ = torch::outer(seq, inv_freq_); + } + + torch::Tensor forward(int seqlen) { + update_freqs_cache(seqlen); + return freqs_cached_.slice(0, 0, seqlen); + } + + private: + int dim_ = 0; + double theta_ = 0.0; + + int64_t seq_len_cached_ = 0; + torch::Tensor inv_freq_; + torch::Tensor freqs_cached_; +}; +TORCH_MODULE(Qwen3_VisionRotaryEmbedding); + +class Qwen3_VisionPatchMergerImpl : public torch::nn::Module { + public: + Qwen3_VisionPatchMergerImpl(const ModelContext& context, + bool use_postshuffle_norm = false) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + auto quant_args = context.get_quant_args(); + auto parallel_args = context.get_parallel_args(); + int64_t d_model = model_args.mm_projection_dim(); + int context_dim = model_args.mm_hidden_size(); + int spatial_merge_size = model_args.mm_spatial_merge_size(); + hidden_size_ = + context_dim * static_cast(std::pow(spatial_merge_size, 2)); + use_postshuffle_norm_ = use_postshuffle_norm; + if (use_postshuffle_norm) + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({hidden_size_}) + .elementwise_affine(true) + .eps(1e-6))); + else + norm_ = register_module( + "norm", + torch::nn::LayerNorm(torch::nn::LayerNormOptions({context_dim}) + .elementwise_affine(true) + .eps(1e-6))); + norm_->weight.set_data(norm_->weight.to(options)); + norm_->bias.set_data(norm_->bias.to(options)); + + auto fc1 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, hidden_size_).bias(true)); + 
fc1->weight.set_data(fc1->weight.to(options)); + fc1->bias.set_data(fc1->bias.to(options)); + auto act = torch::nn::GELU(); + auto fc2 = torch::nn::Linear( + torch::nn::LinearOptions(hidden_size_, d_model).bias(true)); + fc2->weight.set_data(fc2->weight.to(options)); + fc2->bias.set_data(fc2->bias.to(options)); + mlp_ = register_module("mlp", torch::nn::Sequential(fc1, act, fc2)); + layers_ = std::make_tuple(fc1, act, fc2); + } + + torch::Tensor forward(torch::Tensor x) { + if (use_postshuffle_norm_) + x = norm_(x.view({-1, hidden_size_})); + else + x = norm_(x).view({-1, hidden_size_}); + return mlp_->forward(x); + } + + void load_state_dict(const StateDict& state_dict) { + // norm + const auto& norm_dict = state_dict.get_dict_with_prefix("norm."); + const auto& norm_weight = norm_dict.get_tensor("weight"); + if (norm_weight.defined()) { + CHECK_EQ(norm_->weight.sizes(), norm_weight.sizes()) + << "weight size mismatch for " << name(); + norm_->weight.data().copy_(norm_weight); + is_norm_weight_loaded = true; + } + const auto norm_bias = norm_dict.get_tensor("bias"); + if (norm_bias.defined()) { + CHECK_EQ(norm_->bias.sizes(), norm_bias.sizes()) + << "bias size mismatch for " << name(); + norm_->bias.data().copy_(norm_bias); + is_norm_bias_loaded = true; + } + + const auto& fc1_dict = state_dict.get_dict_with_prefix("linear_fc1."); + const auto& fc1_weight = fc1_dict.get_tensor("weight"); + if (fc1_weight.defined()) { + CHECK_EQ(std::get<0>(layers_)->weight.sizes(), fc1_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<0>(layers_)->weight.data().copy_(fc1_weight); + is_fc1_weight_loaded = true; + } + const auto fc1_bias = fc1_dict.get_tensor("bias"); + if (fc1_bias.defined()) { + CHECK_EQ(std::get<0>(layers_)->bias.sizes(), fc1_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<0>(layers_)->bias.data().copy_(fc1_bias); + is_fc1_bias_loaded = true; + } + + const auto& fc2_dict = state_dict.get_dict_with_prefix("linear_fc2."); + const auto& fc2_weight = fc2_dict.get_tensor("weight"); + if (fc2_weight.defined()) { + CHECK_EQ(std::get<2>(layers_)->weight.sizes(), fc2_weight.sizes()) + << "weight size mismatch for " << name(); + std::get<2>(layers_)->weight.data().copy_(fc2_weight); + is_fc2_weight_loaded = true; + } + const auto fc2_bias = fc2_dict.get_tensor("bias"); + if (fc2_bias.defined()) { + CHECK_EQ(std::get<2>(layers_)->bias.sizes(), fc2_bias.sizes()) + << "bias size mismatch for " << name(); + std::get<2>(layers_)->bias.data().copy_(fc2_bias); + is_fc2_bias_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + CHECK(is_fc1_weight_loaded) + << "weight is not loaded for " << prefix + "linear_fc1" + ".weight"; + CHECK(is_fc1_bias_loaded) + << "bias is not loaded for " << prefix + "linear_fc1" + ".bias"; + CHECK(is_fc2_weight_loaded) + << "weight is not loaded for " << prefix + "linear_fc2" + ".weight"; + CHECK(is_fc2_bias_loaded) + << "bias is not loaded for " << prefix + "linear_fc2" + ".bias"; + CHECK(is_norm_weight_loaded) + << "weight is not loaded for " << prefix + "norm" + ".weight"; + CHECK(is_norm_bias_loaded) + << "bias is not loaded for " << prefix + "norm" + ".bias"; + } + + private: + int hidden_size_; + bool use_postshuffle_norm_; + torch::nn::LayerNorm norm_{nullptr}; + torch::nn::Sequential mlp_{nullptr}; + std::tuple layers_ = { + nullptr, + nullptr, + nullptr}; + bool is_fc1_weight_loaded = false; + bool is_fc1_bias_loaded = false; + bool is_fc2_weight_loaded = false; + bool is_fc2_bias_loaded = 
false; + bool is_norm_weight_loaded = false; + bool is_norm_bias_loaded = false; +}; +TORCH_MODULE(Qwen3_VisionPatchMerger); + +class Qwen3_VisionTransformerImpl : public torch::nn::Module { + public: + Qwen3_VisionTransformerImpl(const ModelContext& context) + : options_(context.get_tensor_options()) { + auto model_args = context.get_model_args(); + hidden_size_ = model_args.mm_hidden_size(); + num_heads_ = model_args.mm_num_attention_heads(); + window_size_ = model_args.mm_window_size(); + patch_size_ = model_args.mm_patch_size(); + spatial_merge_size_ = model_args.mm_spatial_merge_size(); + auto& visual_indexes = model_args.mm_deepstack_visual_indexes(); + deepstack_visual_indexes_.insert(deepstack_visual_indexes_.end(), + visual_indexes.begin(), + visual_indexes.end()); + num_position_embeddings_ = model_args.mm_num_position_embeddings(); + spatial_merge_unit_ = + static_cast(spatial_merge_size_ * spatial_merge_size_); + num_grid_per_side_ = static_cast(std::sqrt(num_position_embeddings_)); + + patch_embed_ = + register_module("patch_embed", Qwen3_VisionPatchEmbed(context)); + rotary_pos_emb_ = + register_module("rotary_pos_emb", Qwen3_VisionRotaryEmbedding(context)); + + blocks_ = register_module("blocks", torch::nn::ModuleList()); + deepstack_mergers_ = + register_module("deepstack_mergers", torch::nn::ModuleList()); + + emb_ = register_module( + "embedding", + torch::nn::Embedding(num_position_embeddings_, hidden_size_)); + emb_->weight.set_data(emb_->weight.to(options_)); + + merger_ = register_module("merger", Qwen3_VisionPatchMerger(context)); + + for (int32_t idx = 0; idx < model_args.mm_num_hidden_layers(); idx++) { + auto block = Qwen3_VisionBlock(context); + blocks_->push_back(block); + layers_.push_back(block); + } + for (int32_t idx = 0; idx < deepstack_visual_indexes_.size(); idx++) { + auto merger = Qwen3_VisionPatchMerger(context, true); + deepstack_mergers_->push_back(merger); + deepstack_merger_layers_.push_back(merger); + } + } + + torch::Tensor rot_pos_emb(torch::Tensor grid_thw) { + std::vector pos_ids_vec; + auto count = grid_thw.sizes()[0]; + pos_ids_vec.reserve(count); + + auto grid_thw_cpu = grid_thw.cpu(); + auto options = + torch::TensorOptions().dtype(torch::kLong).device(grid_thw.device()); + + for (int idx = 0; idx < count; ++idx) { + auto t = grid_thw_cpu[idx][0].item(); + auto h = grid_thw_cpu[idx][1].item(); + auto w = grid_thw_cpu[idx][2].item(); + + auto hpos_ids = torch::arange(h, options).unsqueeze(1).expand({-1, w}); + hpos_ids = hpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + auto wpos_ids = torch::arange(w, options).unsqueeze(0).expand({h, -1}); + wpos_ids = wpos_ids + .reshape({h / spatial_merge_size_, + spatial_merge_size_, + w / spatial_merge_size_, + spatial_merge_size_}) + .permute({0, 2, 1, 3}) + .flatten(); + + pos_ids_vec.push_back( + torch::stack({hpos_ids, wpos_ids}, -1).repeat({t, 1})); + } + + auto pos_ids = torch::cat(pos_ids_vec, 0); + auto max_grid_size = + grid_thw + .index({torch::indexing::Slice(), + torch::indexing::Slice(1, torch::indexing::None)}) + .max(); + + auto rotary_pos_emb_full = rotary_pos_emb_(max_grid_size.item()); + auto rotary_pos_emb = rotary_pos_emb_full.index({pos_ids}).flatten(1); + + return rotary_pos_emb; + } + + torch::Tensor fast_pos_embed_interpolate(const torch::Tensor& grid_thw) { + auto device = grid_thw.device(); + int64_t hidden_dim = hidden_size_; + int64_t m_size = spatial_merge_size_; 
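+    // Bilinear interpolation of the learned position-embedding table: each
+    // image's h x w patch grid is mapped onto the square
+    // num_grid_per_side_ x num_grid_per_side_ table, the four nearest
+    // entries are blended with weights w00/w01/w10/w11, and the rows are
+    // then reordered into m_size x m_size merge windows so they follow the
+    // same ordering as rot_pos_emb().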
+
+    auto grid_cpu = grid_thw.to(torch::kCPU);
+    int64_t count = grid_thw.size(0);
+
+    std::vector<torch::Tensor> outputs;
+    outputs.reserve(count);
+
+    for (int64_t idx = 0; idx < count; ++idx) {
+      int64_t t = grid_cpu[idx][0].item<int64_t>();
+      int64_t h = grid_cpu[idx][1].item<int64_t>();
+      int64_t w = grid_cpu[idx][2].item<int64_t>();
+
+      auto h_idxs =
+          torch::linspace(
+              0, static_cast<float>(num_grid_per_side_ - 1), h, torch::kFloat32)
+              .to(device);
+      auto w_idxs =
+          torch::linspace(
+              0, static_cast<float>(num_grid_per_side_ - 1), w, torch::kFloat32)
+              .to(device);
+
+      auto h_floor = h_idxs.to(torch::kLong);
+      auto w_floor = w_idxs.to(torch::kLong);
+      auto h_ceil = torch::clamp(h_floor + 1, 0, num_grid_per_side_ - 1);
+      auto w_ceil = torch::clamp(w_floor + 1, 0, num_grid_per_side_ - 1);
+
+      auto dh = h_idxs - h_floor;
+      auto dw = w_idxs - w_floor;
+
+      auto mesh_d = torch::meshgrid({dh, dw}, "ij");
+      auto dh_grid = mesh_d[0], dw_grid = mesh_d[1];
+
+      auto mesh_floor = torch::meshgrid({h_floor, w_floor}, "ij");
+      auto h_floor_grid = mesh_floor[0];
+      auto w_floor_grid = mesh_floor[1];
+
+      auto mesh_ceil = torch::meshgrid({h_ceil, w_ceil}, "ij");
+      auto h_ceil_grid = mesh_ceil[0];
+      auto w_ceil_grid = mesh_ceil[1];
+
+      auto h_floor_grid_idx = h_floor_grid * num_grid_per_side_;
+      auto h_ceil_grid_idx = h_ceil_grid * num_grid_per_side_;
+
+      auto w11 = dh_grid * dw_grid;
+      auto w10 = dh_grid - w11;
+      auto w01 = dw_grid - w11;
+      auto w00 = 1.0f - dh_grid - dw_grid + w11;
+
+      auto idx00 = h_floor_grid_idx + w_floor_grid;
+      auto idx01 = h_floor_grid_idx + w_ceil_grid;
+      auto idx10 = h_ceil_grid_idx + w_floor_grid;
+      auto idx11 = h_ceil_grid_idx + w_ceil_grid;
+
+      auto indices = torch::stack({idx00, idx01, idx10, idx11}, 0)
+                         .reshape({4, -1})
+                         .to(torch::kLong);
+      auto weights = torch::stack({w00, w01, w10, w11}, 0)
+                         .reshape({4, -1, 1})
+                         .to(options_);
+
+      auto embeds = emb_(indices);
+
+      auto combined = (embeds * weights).sum(0);  // [h*w, hidden_dim]
+
+      auto repeated = combined.unsqueeze(0).expand({t, -1, -1}).contiguous();
+      repeated = repeated.view(
+          {t, h / m_size, m_size, w / m_size, m_size, hidden_dim});
+      repeated = repeated.permute({0, 1, 3, 2, 4, 5}).reshape({-1, hidden_dim});
+
+      outputs.push_back(repeated);
+    }
+
+    return torch::cat(outputs, 0);
+  }
+
+  std::tuple<torch::Tensor, std::vector<torch::Tensor>> forward(
+      torch::Tensor hidden_states,
+      torch::Tensor grid_thw,  // [batch,thw]
+      const ModelInputParams& input_params) {
+    hidden_states = patch_embed_(hidden_states);
+    auto pos_embeds = fast_pos_embed_interpolate(grid_thw);
+    hidden_states = hidden_states + pos_embeds;
+    // compute position embedding
+    auto rotary_pos_emb = rot_pos_emb(grid_thw);
+    // compute cu_seqlens
+    auto cu_seqlens = torch::repeat_interleave(
+                          grid_thw.index({torch::indexing::Slice(), 1}) *
+                              grid_thw.index({torch::indexing::Slice(), 2}),
+                          grid_thw.index({torch::indexing::Slice(), 0}))
+                          .cumsum(0, torch::kInt32);
+    namespace F = torch::nn::functional;
+    cu_seqlens = F::pad(
+        cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
+
+    // transformers
+    cu_seqlens = torch::diff(cu_seqlens);
+
+    m_cos = rotary_pos_emb.cos().type_as(hidden_states);
+    m_cos = m_cos.repeat({1, 2});
+    m_sin = rotary_pos_emb.sin().type_as(hidden_states);
+    m_sin = m_sin.repeat({1, 2});
+
+    ModelInputParams& input_params_new =
+        const_cast<ModelInputParams&>(input_params);
+    torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
+    std::vector<int32_t> cu_seqlens_vec(
+        cu_seqlens_cpu.data_ptr<int32_t>(),  // full seqlen vec
+        cu_seqlens_cpu.data_ptr<int32_t>() + cu_seqlens_cpu.numel());
+    std::vector<torch::Tensor> deepstack_feature_lists;
+    
deepstack_feature_lists.reserve(deepstack_visual_indexes_.size()); + for (int idx = 0; idx < blocks_->size(); ++idx) { + hidden_states = layers_[idx](hidden_states, + m_cos, + m_sin, + cu_seqlens, + cu_seqlens_vec, + input_params_new, + idx); + auto it = std::find(deepstack_visual_indexes_.begin(), + deepstack_visual_indexes_.end(), + idx); + + if (it != deepstack_visual_indexes_.end()) { + int index = std::distance(deepstack_visual_indexes_.begin(), it); + deepstack_feature_lists.push_back( + deepstack_merger_layers_[index](hidden_states)); + } + } + // adapter + hidden_states = merger_(hidden_states); + return std::make_tuple(hidden_states, deepstack_feature_lists); + } + + void load_state_dict(const StateDict& state_dict) { + patch_embed_->load_state_dict( + state_dict.get_dict_with_prefix("patch_embed.")); + for (int idx = 0; idx < layers_.size(); ++idx) { + layers_[idx]->load_state_dict(state_dict.get_dict_with_prefix( + "blocks." + std::to_string(idx) + ".")); + } + + merger_->load_state_dict(state_dict.get_dict_with_prefix("merger.")); + + for (int idx = 0; idx < deepstack_merger_layers_.size(); ++idx) { + deepstack_merger_layers_[idx]->load_state_dict( + state_dict.get_dict_with_prefix("deepstack_merger_list." + + std::to_string(idx) + ".")); + } + + const auto& emb_dict = state_dict.get_dict_with_prefix("pos_embed."); + const auto& emb_weight = emb_dict.get_tensor("weight"); + if (emb_weight.defined()) { + CHECK_EQ(emb_->weight.sizes(), emb_weight.sizes()) + << "weight size mismatch for " << name(); + emb_->weight.data().copy_(emb_weight); + is_emb_weight_loaded = true; + } + } + + void verify_loaded_weights(const std::string& prefix) const { + patch_embed_->verify_loaded_weights(prefix + "patch_embed."); + for (int idx = 0; idx < blocks_->size(); ++idx) { + layers_[idx]->verify_loaded_weights(prefix + "blocks." + + std::to_string(idx) + "."); + } + merger_->verify_loaded_weights(prefix + "merger."); + + for (int idx = 0; idx < deepstack_merger_layers_.size(); ++idx) { + deepstack_merger_layers_[idx]->verify_loaded_weights( + "deepstack_merger_list." 
+ std::to_string(idx) + "."); + } + CHECK(is_emb_weight_loaded) + << "weight is not loaded for " << prefix + "" + ".bias"; + } + + void merge_loaded_weights() { + for (int idx = 0; idx < layers_.size(); ++idx) { + layers_[idx]->merge_loaded_weights(); + } + } + + private: + int hidden_size_ = 0; + int num_heads_ = 0; + int window_size_ = 0; + int patch_size_ = 0; + int spatial_merge_size_ = 0; + std::vector deepstack_visual_indexes_; + int spatial_merge_unit_ = 0; + int64_t num_position_embeddings_ = 0; + int num_grid_per_side_ = 0; + + Qwen3_VisionPatchEmbed patch_embed_{nullptr}; + Qwen3_VisionRotaryEmbedding rotary_pos_emb_{nullptr}; + torch::nn::Embedding emb_{nullptr}; + + torch::nn::ModuleList blocks_{nullptr}; + std::vector layers_; + + torch::nn::ModuleList deepstack_mergers_{nullptr}; + std::vector deepstack_merger_layers_; + Qwen3_VisionPatchMerger merger_{nullptr}; + + torch::Tensor m_cos; + torch::Tensor m_sin; + int device_id = 0; + bool is_emb_weight_loaded = false; + torch::TensorOptions options_; +}; +TORCH_MODULE(Qwen3_VisionTransformer); + +struct Qwen3_VLImageInputs { + torch::Tensor pixel_values; + torch::Tensor image_grid_thw; +}; + +struct Qwen3_VLVideoInputs { + torch::Tensor pixel_values_videos; + torch::Tensor video_grid_thw; + torch::Tensor second_per_grid_ts; +}; + +class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module { + public: + Qwen3_VLForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Qwen3_VisionTransformer(context)); + language_model_ = + register_module("language_model", QWen3ForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (image_input) { + // visual + auto [image_embeds, deep_stacks] = + visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + input_params.deep_stacks = deep_stacks; + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); + input_params.visual_pos_masks = is_multimodal; + inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, + const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; + + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : 
loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } + + // verify + visual_->verify_loaded_weights("model.visual."); + visual_->merge_loaded_weights(); + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader), "model.language_model."); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + + Qwen3_VisionTransformer visual_{nullptr}; + QWen3ForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen3_VLForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(qwen3_vl, Qwen2_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(qwen3_vl, Qwen3_VLForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(qwen3_vl, Qwen2VLImageProcessor); + +REGISTER_MODEL_ARGS(qwen3_vl, [&] { + // text config + // LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0); + LOAD_ARG_OR(model_type, "model_type", "qwen3_vl"); + LOAD_ARG_OR(bos_token_id, "text_config.bos_token_id", 151643); + LOAD_ARG_OR(eos_token_id, "text_config.eos_token_id", 151645); + LOAD_ARG_OR( + vision_start_token_id, "text_config.vision_start_token_id", 151652); + LOAD_ARG_OR(vision_end_token_id, "text_config.vision_end_token_id", 151653); + LOAD_ARG_OR(vision_token_id, "text_config.vision_token_id", 151654); + LOAD_ARG_OR(image_token_id, "text_config.image_token_id", 151655); + LOAD_ARG_OR(video_token_id, "text_config.video_token_id", 151656); + LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 3584); + LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 18944); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 128000); + LOAD_ARG_OR(max_window_layers, "text_config.max_window_layers", 28); + LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 32); + LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 48); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 4); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-06); + LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 5000000.0f); + LOAD_ARG_OR(sliding_window, "text_config.sliding_window", 32768); + LOAD_ARG_OR(tie_word_embeddings, "text_config.tie_word_embeddings", false); + LOAD_ARG(rope_scaling_mrope_section, + "text_config.rope_scaling.mrope_section"); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + // LOAD_ARG_OR(transformers_version, "transformers_version", "4.41.2"); + // LOAD_ARG_OR(use_cache, "use_cache", true); + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR_FUNC(head_dim, "text_config.head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + // vision_config + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 27); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "gelu_pytorch_tanh"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1152); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4304); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + LOAD_ARG_OR(mm_projection_dim, 
"vision_config.out_hidden_size", 4096); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 16); + LOAD_ARG_OR(mm_num_position_embeddings, + "vision_config.num_position_embeddings", + 2304); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG(mm_deepstack_visual_indexes, + "vision_config.deepstack_visual_indexes"); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + + LOAD_ARG_OR( + rope_scaling_rope_type, "vision_config.rope_scaling.type", "mrope"); + + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151936); +}); +} // namespace xllm diff --git a/xllm/models/vlm/npu/qwen3_vl_moe.h b/xllm/models/vlm/npu/qwen3_vl_moe.h new file mode 100644 index 000000000..a9de36647 --- /dev/null +++ b/xllm/models/vlm/npu/qwen3_vl_moe.h @@ -0,0 +1,206 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "core/framework/kv_cache/kv_cache.h" +#include "core/framework/model/model_input_params.h" +#include "core/framework/model_context.h" +#include "core/layers/lm_head.h" +#include "core/layers/qwen3_vision_encode_layer.h" +#include "core/layers/rms_norm.h" +#include "models/llm/npu/qwen3_moe.h" +#include "models/model_registry.h" +#include "processors/input_processor.h" +#include "processors/qwen2_vl_image_processor.h" +#include "qwen2_5_vl.h" +#include "qwen3_vl.h" +#include "xllm_kernels/core/include/atb_speed/log.h" + +namespace xllm { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module { + public: + Qwen3_VLMoeForConditionalGenerationImpl(const ModelContext& context) + : model_args_(context.get_model_args()), + options_(context.get_tensor_options()) { + visual_ = register_module("visual", Qwen3_VisionTransformer(context)); + + language_model_ = + register_module("language_model", Qwen3MoeForCausalLM(context)); + } + + torch::Tensor get_input_embeddings( + torch::Tensor input_ids, + const std::optional& image_input, + const std::optional& video_input, + const ModelInputParams& input_params) { + auto inputs_embeds = language_model_->get_input_embeddings(input_ids); + if (image_input) { + // visual + auto [image_embeds, deep_stacks] = + visual_(image_input->pixel_values.to(options_), + image_input->image_grid_thw, + input_params); + input_params.deep_stacks = deep_stacks; + // merge + auto is_multimodal = torch::isin(input_ids, model_args_.image_token_id()); + input_params.visual_pos_masks = is_multimodal; + inputs_embeds.index_put_({is_multimodal}, image_embeds); + } + return inputs_embeds; + } + + torch::Tensor forward(const torch::Tensor& tokens, + const torch::Tensor& positions, + std::vector& kv_caches, 
+ const ModelInputParams& input_params) { + torch::NoGradGuard no_grad; + const auto& mm_data = input_params.mm_data; + torch::Tensor pixel_values; + if (const auto& res = mm_data.get("pixel_values")) + pixel_values = res.value(); + + torch::Tensor image_grid_thw; + if (const auto& res = mm_data.get("image_grid_thw")) + image_grid_thw = res.value(); + std::optional image_inputs; + std::optional video_inputs; + + if (pixel_values.defined() && image_grid_thw.defined()) + image_inputs = Qwen3_VLImageInputs{pixel_values, image_grid_thw}; + + auto inputs_embeds = + get_input_embeddings(tokens, image_inputs, video_inputs, input_params); + input_params.input_embedding = inputs_embeds; + auto emb = language_model_(tokens, positions, kv_caches, input_params); + + return emb; + } + + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes) { + return language_model_->logits(hidden_states, seleted_idxes); + } + + void load_model(std::unique_ptr loader) { + for (const auto& state_dict : loader->get_state_dicts()) { + visual_->load_state_dict( + state_dict->get_dict_with_prefix("model.visual.")); + } + visual_->verify_loaded_weights("model.visual."); + visual_->merge_loaded_weights(); + + if (!model_args_.image_embedding_mode()) { + language_model_->load_model(std::move(loader), "model.language_model."); + } + } + + layer::LmHead get_lm_head() { return language_model_->get_lm_head(); } + void set_lm_head(layer::LmHead& head) { language_model_->set_lm_head(head); } + + layer::WordEmbedding get_word_embedding() { + return language_model_->get_word_embedding(); + } + + void set_word_embedding(layer::WordEmbedding& word_embedding) { + language_model_->set_word_embedding(word_embedding); + } + + private: + ModelArgs model_args_; + torch::TensorOptions options_; + Qwen3_VisionTransformer visual_{nullptr}; + Qwen3MoeForCausalLM language_model_{nullptr}; +}; +TORCH_MODULE(Qwen3_VLMoeForConditionalGeneration); + +REGISTER_INPUT_PROCESSOR(qwen3_vl_moe, Qwen2_5_VLInputProcessor); +REGISTER_CAUSAL_VLM_MODEL(qwen3_vl_moe, Qwen3_VLMoeForConditionalGeneration); +REGISTER_IMAGE_PROCESSOR(qwen3_vl_moe, Qwen2VLImageProcessor); +// register the model args +REGISTER_MODEL_ARGS(qwen3_vl_moe, [&] { + // text config + LOAD_ARG_OR(model_type, "model_type", "qwen3_vl_moe"); + LOAD_ARG_OR(attention_bias, "text_config.attention_bias", false); + LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f); + LOAD_ARG_OR(bos_token_id, "text_config.bos_token_id", 151643); + LOAD_ARG_OR(decoder_sparse_step, "text_config.decoder_sparse_step", 1); + LOAD_ARG_OR(dtype, "text_config.dtype", "bfloat16"); + LOAD_ARG_OR(eos_token_id, "text_config.eos_token_id", 151645); + LOAD_ARG_OR_FUNC(head_dim, "text_config.head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + LOAD_ARG_OR(hidden_act, "text_config.hidden_act", "silu"); + LOAD_ARG_OR(hidden_size, "text_config.hidden_size", 2048); + LOAD_ARG_OR(initializer_range, "text_config.initializer_range", 0.02); + LOAD_ARG_OR(intermediate_size, "text_config.intermediate_size", 5632); + LOAD_ARG_OR( + max_position_embeddings, "text_config.max_position_embeddings", 128000); + // LOAD_ARG(mlp_only_layers, "text_config.mlp_only_layers"); + LOAD_ARG_OR(moe_intermediate_size, "text_config.moe_intermediate_size", 1408); + LOAD_ARG_OR(norm_topk_prob, "text_config.norm_topk_prob", true); + LOAD_ARG_OR(n_heads, "text_config.num_attention_heads", 16); + LOAD_ARG_OR(num_experts, "text_config.num_experts", 128); + LOAD_ARG_OR(num_experts_per_tok, 
"text_config.num_experts_per_tok", 8); + LOAD_ARG_OR(n_layers, "text_config.num_hidden_layers", 24); + LOAD_ARG_OR(n_kv_heads, "text_config.num_key_value_heads", 16); + LOAD_ARG_OR(rms_norm_eps, "text_config.rms_norm_eps", 1e-06); + LOAD_ARG_OR(rope_scaling_rope_type, "text_config.rope_scaling.type", "mrope"); + LOAD_ARG(rope_scaling_mrope_section, + "text_config.rope_scaling.mrope_section"); + // LOAD_ARG_OR(rope_scaling_mrope_interleaved,"text_config.rope_scaling.mrope_interleaved",true); + LOAD_ARG_OR(rope_theta, "text_config.rope_theta", 5000000.0f); + LOAD_ARG_OR(vocab_size, "text_config.vocab_size", 151936); + + // vision config + LOAD_ARG(mm_deepstack_visual_indexes, + "vision_config.deepstack_visual_indexes"); + LOAD_ARG_OR(mm_num_hidden_layers, "vision_config.depth", 27); + LOAD_ARG_OR(mm_hidden_act, "vision_config.hidden_act", "gelu_pytorch_tanh"); + LOAD_ARG_OR(mm_hidden_size, "vision_config.hidden_size", 1152); + LOAD_ARG_OR(mm_num_channels, "vision_config.in_channels", 3); + LOAD_ARG_OR(mm_initializer_range, "vision_config.initializer_range", 0.02); + LOAD_ARG_OR(mm_intermediate_size, "vision_config.intermediate_size", 4304); + LOAD_ARG_OR(mm_num_attention_heads, "vision_config.num_heads", 16); + LOAD_ARG_OR(mm_num_position_embeddings, + "vision_config.num_position_embeddings", + 2304); + LOAD_ARG_OR(mm_projection_dim, "vision_config.out_hidden_size", 3584); + LOAD_ARG_OR(mm_patch_size, "vision_config.patch_size", 16); + LOAD_ARG_OR(mm_spatial_merge_size, "vision_config.spatial_merge_size", 2); + LOAD_ARG_OR(mm_temporal_patch_size, "vision_config.temporal_patch_size", 2); + LOAD_ARG_OR_FUNC(mm_head_dim, "head_dim", [&] { + return args->mm_hidden_size() / args->mm_num_attention_heads(); + }); + + LOAD_ARG_OR(image_token_id, "image_token_id", 151655); + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + LOAD_ARG_OR(video_token_id, "video_token_id", 151656); + LOAD_ARG_OR(vision_end_token_id, "vision_end_token_id", 151653); + LOAD_ARG_OR(vision_start_token_id, "vision_start_token_id", 151652); +}); +} // namespace xllm diff --git a/xllm/models/vlm/qwen2_5_vl.h b/xllm/models/vlm/qwen2_5_vl.h index c6bd62bb8..8edcfaa8d 100644 --- a/xllm/models/vlm/qwen2_5_vl.h +++ b/xllm/models/vlm/qwen2_5_vl.h @@ -15,12 +15,6 @@ limitations under the License. 
 #pragma once
 
-#if defined(USE_NPU)
-#include
-
-#include "xllm_kernels/core/include/atb_speed/log.h"
-#endif
-
 #include
 #include
 
@@ -178,13 +172,6 @@ class Qwen2_5_VisionBlockImpl : public torch::nn::Module {
     encoder_layer_->load_state_dict(state_dict);
   }
 
-#if defined(USE_NPU)
-  void verify_loaded_weights(const std::string& prefix) const {
-    encoder_layer_->verify_loaded_weights();
-  }
-  void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); }
-#endif
-
  private:
   layer::Qwen2dot5VisionEncoderLayer encoder_layer_{nullptr};
 };
 
@@ -293,13 +280,9 @@ class Qwen2_5_VisionPatchMergerImpl : public torch::nn::Module {
     hidden_size_ =
        context_dim * static_cast<int64_t>(std::pow(spatial_merge_size, 2));
 
-#if defined(USE_NPU)
-    ln_q_ = register_module("ln_q", layer::RmsNorm(context));
-#else
     ln_q_ = register_module(
         "ln_q",
         layer::RmsNorm(context_dim, model_args.rms_norm_eps(), options));
-#endif
 
     auto cpl = torch::nn::Linear(
         torch::nn::LinearOptions(hidden_size_, hidden_size_).bias(true));
 
@@ -315,11 +298,7 @@ class Qwen2_5_VisionPatchMergerImpl : public torch::nn::Module {
   }
 
   torch::Tensor forward(torch::Tensor x) {
-#if defined(USE_NPU)
-    x = ln_q_(x, 0);
-#elif defined(USE_MLU)
     x = ln_q_(x);
-#endif
     x = x.view({-1, hidden_size_});
     return mlp_->forward(x);
   }
 
@@ -360,22 +339,6 @@ class Qwen2_5_VisionPatchMergerImpl : public torch::nn::Module {
     }
   }
 
-#if defined(USE_NPU)
-  void verify_loaded_weights(const std::string& prefix) const {
-    ln_q_->verify_loaded_weights(prefix + "ln_q.");
-    CHECK(is_cpl_weight_loaded)
-        << "weight is not loaded for " << prefix + "mlp.0" + ".weight";
-    CHECK(is_cpl_bias_loaded)
-        << "bias is not loaded for " << prefix + "mlp.0" + ".bias";
-    CHECK(is_rpl_weight_loaded)
-        << "weight is not loaded for " << prefix + "mlp.2" + ".weight";
-    CHECK(is_rpl_bias_loaded)
-        << "bias is not loaded for " << prefix + "mlp.2" + ".bias";
-  }
-
-  void merge_loaded_weights() { ln_q_->merge_loaded_weights(); }
-#endif
-
  private:
   int64_t hidden_size_;
 
@@ -585,17 +548,6 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module {
     m_cos = rotary_pos_emb.cos().type_as(hidden_states);
     m_sin = rotary_pos_emb.sin().type_as(hidden_states);
 
-#if defined(USE_NPU)
-    // transformers
-    cu_seqlens = torch::diff(cu_seqlens);
-    cu_window_seqlens = torch::diff(cu_window_seqlens);
-
-    m_cos = torch::nn::functional::pad(
-        m_cos, torch::nn::functional::PadFuncOptions({0, 24}));
-    m_sin = torch::nn::functional::pad(
-        m_sin, torch::nn::functional::PadFuncOptions({0, 24}));
-#endif
-
     m_cos = m_cos.repeat({1, 2});
     m_sin = m_sin.repeat({1, 2});
     ModelInputParams& input_params_new =
@@ -646,24 +598,6 @@ class Qwen2_5_VisionTransformerImpl : public torch::nn::Module {
     merger_->load_state_dict(state_dict.get_dict_with_prefix("merger."));
   }
 
-#if defined(USE_NPU)
-  void verify_loaded_weights(const std::string& prefix) const {
-    patch_embed_->verify_loaded_weights(prefix + "patch_embed.");
-    for (int idx = 0; idx < blocks_->size(); ++idx) {
-      layers_[idx]->verify_loaded_weights(prefix + "blocks." +
-                                          std::to_string(idx) + ".");
-    }
-    merger_->verify_loaded_weights(prefix + "merger.");
-  }
-
-  void merge_loaded_weights() {
-    for (int idx = 0; idx < blocks_->size(); ++idx) {
-      layers_[idx]->merge_loaded_weights();
-    }
-    merger_->merge_loaded_weights();
-  }
-#endif
-
  private:
   int hidden_size_ = 0;
   int num_heads_ = 0;
@@ -764,11 +698,6 @@ class Qwen2_5_VLForConditionalGenerationImpl : public torch::nn::Module {
       visual_->load_state_dict(state_dict->get_dict_with_prefix("visual."));
     }
 
-#if defined(USE_NPU)
-    // verify
-    visual_->verify_loaded_weights("visual.");
-    visual_->merge_loaded_weights();
-#endif
     if (!model_args_.image_embedding_mode()) {
       language_model_->load_model(std::move(loader));
     }
diff --git a/xllm/models/vlm/qwen3_vl.h b/xllm/models/vlm/qwen3_vl.h
index c4c2896a1..b15e92f60 100755
--- a/xllm/models/vlm/qwen3_vl.h
+++ b/xllm/models/vlm/qwen3_vl.h
@@ -15,11 +15,6 @@ limitations under the License.
 
 #pragma once
 
-#if defined(USE_NPU)
-#include
-
-#include "xllm_kernels/core/include/atb_speed/log.h"
-#endif
 #include
 #include
 #include
 
@@ -127,13 +122,6 @@ class Qwen3_VisionBlockImpl : public torch::nn::Module {
     encoder_layer_->load_state_dict(state_dict);
   }
 
-#if defined(USE_NPU)
-  void verify_loaded_weights(const std::string& prefix) const {
-    encoder_layer_->verify_loaded_weights();
-  }
-  void merge_loaded_weights() { encoder_layer_->merge_loaded_weights(); }
-#endif
-
  private:
   layer::Qwen3VisionEncoderLayer encoder_layer_{nullptr};
 };
 
@@ -512,11 +500,6 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module {
     cu_seqlens = F::pad(
         cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
 
-#if defined(USE_NPU)
-    // transformers
-    cu_seqlens = torch::diff(cu_seqlens);
-#endif
-
     m_cos = rotary_pos_emb.cos().type_as(hidden_states);
     m_cos = m_cos.repeat({1, 2});
     m_sin = rotary_pos_emb.sin().type_as(hidden_states);
 
@@ -579,30 +562,6 @@ class Qwen3_VisionTransformerImpl : public torch::nn::Module {
     }
   }
 
-#if defined(USE_NPU)
-  void verify_loaded_weights(const std::string& prefix) const {
-    patch_embed_->verify_loaded_weights(prefix + "patch_embed.");
-    for (int idx = 0; idx < blocks_->size(); ++idx) {
-      layers_[idx]->verify_loaded_weights(prefix + "blocks." +
-                                          std::to_string(idx) + ".");
-    }
-    merger_->verify_loaded_weights(prefix + "merger.");
-
-    for (int idx = 0; idx < deepstack_merger_layers_.size(); ++idx) {
-      deepstack_merger_layers_[idx]->verify_loaded_weights(
-          "deepstack_merger_list." + std::to_string(idx) + ".");
-    }
-    CHECK(is_emb_weight_loaded)
-        << "weight is not loaded for " << prefix + "" + ".bias";
-  }
-
-  void merge_loaded_weights() {
-    for (int idx = 0; idx < layers_.size(); ++idx) {
-      layers_[idx]->merge_loaded_weights();
-    }
-  }
-#endif
-
  private:
   int hidden_size_ = 0;
   int num_heads_ = 0;
@@ -713,11 +672,6 @@ class Qwen3_VLForConditionalGenerationImpl : public torch::nn::Module {
           state_dict->get_dict_with_prefix("model.visual."));
     }
 
-#if defined(USE_NPU)
-    // verify
-    visual_->verify_loaded_weights("model.visual.");
-    visual_->merge_loaded_weights();
-#endif
     if (!model_args_.image_embedding_mode()) {
       language_model_->load_model(std::move(loader), "model.language_model.");
     }
diff --git a/xllm/models/vlm/qwen3_vl_moe.h b/xllm/models/vlm/qwen3_vl_moe.h
index 8be8b5811..0582c3df5 100644
--- a/xllm/models/vlm/qwen3_vl_moe.h
+++ b/xllm/models/vlm/qwen3_vl_moe.h
@@ -15,16 +15,10 @@ limitations under the License.
 #pragma once
 
-#if defined(USE_NPU)
-#include
-
-#include "xllm_kernels/core/include/atb_speed/log.h"
-#endif
 #include
 #include
 #include
-#include
 #include
 
 #include "core/framework/kv_cache/kv_cache.h"
@@ -114,11 +108,7 @@ class Qwen3_VLMoeForConditionalGenerationImpl : public torch::nn::Module {
       visual_->load_state_dict(
          state_dict->get_dict_with_prefix("model.visual."));
     }
-#if defined(USE_NPU)
-    // verify
-    visual_->verify_loaded_weights("model.visual.");
-    visual_->merge_loaded_weights();
-#endif
+
    if (!model_args_.image_embedding_mode()) {
      language_model_->load_model(std::move(loader), "model.language_model.");
    }