2 changes: 1 addition & 1 deletion xllm/core/layers/common/indexer.h
@@ -25,7 +25,7 @@ limitations under the License.
#include "../mlu/attention.h"
#elif defined(USE_CUDA)
#include "../cuda/attention.h"
#endif #include "framework/kv_cache/kv_cache.h"
#endif
#include "framework/model/model_input_params.h"
#include "framework/parallel_state/parallel_args.h"
#include "framework/quant_args.h"
1 change: 0 additions & 1 deletion xllm/models/CMakeLists.txt
@@ -1,6 +1,5 @@
include(cc_library)

# Define the library
cc_library(
NAME
models
277 changes: 54 additions & 223 deletions xllm/models/llm/deepseek_v2.h
@@ -12,29 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <gflags/gflags.h>
#include <torch/torch.h>

#include <boost/algorithm/string.hpp>
#include <string>
#include <vector>

#include "core/common/global_flags.h"
#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/model/model_input_params.h"
#include "core/framework/model/npu_dp_ep_padding.h"
#include "core/framework/model_context.h"
#include "core/layers/attention_mask.h"
#include "core/layers/deepseek_v2_decoder_layer.h"
#include "core/layers/lm_head.h"
#include "core/layers/pos_embedding.h"
#include "core/layers/rms_norm.h"
#include "core/layers/rotary_embedding.h"
#include "core/layers/word_embedding.h"
#include "models/model_registry.h"
#include "llm_model_base.h"

// DeepSeek v2 compatible with huggingface weights
// ref to:
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py
@@ -46,47 +33,29 @@ using ISlice = torch::indexing::Slice;

class DeepseekV2DecoderLayerImpl : public torch::nn::Module {
public:
DeepseekV2DecoderLayerImpl(const ModelContext& context,
const int32_t i,
const float sm_scale) {
DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) {
// register submodules
decoder_layer_ = register_module(
"decoder_layer", layer::DeepseekV2DecoderLayer(context, i, sm_scale));
decoder_layer_ = register_module("decoder_layer",
layer::DeepseekV2DecoderLayer(context, i));
}

torch::Tensor forward(torch::Tensor& x,
torch::Tensor& cos_pos,
torch::Tensor& sin_pos,
torch::Tensor& attn_mask,
torch::Tensor& positions,
const layer::AttentionMetadata& attn_metadata,
KVCache& kv_cache,
const ModelInputParams& input_params,
aclrtEvent* event,
std::atomic<bool>* event_flag) {
return decoder_layer_(x,
cos_pos,
sin_pos,
attn_mask,
kv_cache,
input_params,
event,
event_flag);
const ModelInputParams& input_params) {
return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
}

void load_state_dict(const StateDict& state_dict) {
decoder_layer_->load_state_dict(state_dict);
}

void verify_loaded_weights(const std::string& prefix) const {
decoder_layer_->verify_loaded_weights(prefix);
}

void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }

void prepare_expert_weight(const std::vector<int32_t>& expert_list) {
decoder_layer_->prepare_expert_weight(expert_list);
virtual void prepare_expert_weight(int32_t layer_id,
const std::vector<int32_t>& expert_ids) {
return;
}

void update_expert_weight() { decoder_layer_->update_expert_weight(); }
virtual void update_expert_weight(int32_t layer_id) { return; }

private:
layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
@@ -95,114 +64,71 @@ TORCH_MODULE(DeepseekV2DecoderLayer);

class DeepseekV2ModelImpl : public torch::nn::Module {
public:
DeepseekV2ModelImpl(const ModelContext& context)
: device_(context.get_tensor_options().device()) {
DeepseekV2ModelImpl(const ModelContext& context) {
auto options = context.get_tensor_options();
auto model_args = context.get_model_args();
auto parallel_args = context.get_parallel_args();

blocks_ = register_module("layers", torch::nn::ModuleList());
layers_.reserve(model_args.n_layers());

// register submodules
device_ = options.device();
dtype_ = options.dtype().toScalarType();
num_speculative_tokens_ = model_args.num_speculative_tokens();

// rotary positional embedding
auto inv_freq = rotary::apply_deepseek_yarn_rope_scaling(
model_args.rope_scaling_factor(),
model_args.rope_extrapolation_factor(),
model_args.rope_scaling_beta_fast(),
model_args.rope_scaling_beta_slow(),
model_args.rotary_dim(),
model_args.rope_theta(),
model_args.rope_scaling_original_max_position_embeddings());
embed_tokens_ =
register_module("embed_tokens", layer::WordEmbedding(context));
float sm_scale = 1.0f;
pos_emb_ = create_rotary_embedding(model_args,
model_args.rotary_dim(),
inv_freq,
/*interleaved=*/false,
sm_scale,
options);
atb_pos_emb_ = layer::PosEmbedding(context);

max_seq_len_ = model_args.max_position_embeddings();
int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
attn_mask_ = layer::AttentionMask(options.device(),
options.dtype().toScalarType(),
/*mask_value=*/mask_value);
// MTP is not supported for now
if (num_speculative_tokens_ > 0) {
LOG(FATAL) << "DeepSeek MTP on MLU is not supported for now";
}

embed_tokens_ =
register_module("embed_tokens",
layer::WordEmbedding(model_args.vocab_size(),
model_args.hidden_size(),
context.get_parallel_args(),
options));
norm_ = register_module(
"norm",
layer::RmsNorm(
model_args.hidden_size(), model_args.rms_norm_eps(), options));

// create decoder layers
for (int32_t i = 0; i < model_args.n_layers(); ++i) {
auto block = DeepseekV2DecoderLayer(context, i, sm_scale);
auto block = DeepseekV2DecoderLayer(context, i);
layers_.push_back(block);
blocks_->push_back(block);
}

norm_ = register_module("norm", layer::RmsNorm(context));
// dp_size_=4;
dp_size_ = parallel_args.dp_size();
std::vector<int64_t> indices;
dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
rank_ = parallel_args.rank();
mapping_data_ = parallel_args.mapping_data();
num_experts_per_tok_ = model_args.num_experts_per_tok();
for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
indices.push_back(i);
}
}

torch::Tensor forward(torch::Tensor tokens,
torch::Tensor positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
if (dp_size_ > 1) {
if (tokens.sizes() == 0) {
tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
positions = torch::tensor({0}).to(torch::kInt32).to(device_);
}
}

auto h = embed_tokens_(tokens, 0);
auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
auto cos_pos = cos_sin_chunks[0].contiguous();
auto sin_pos = cos_sin_chunks[1].contiguous();

torch::Tensor attn_mask;
if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
} else {
attn_mask = attn_mask_.gen_free_mask(
num_speculative_tokens_ + 1, dtype_, device_);
}

torch::Tensor forward_native(torch::Tensor tokens,
torch::Tensor positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
bool is_prefill = input_params.q_max_seq_len > 1;
auto attn_metadata =
layer::AttentionMetadata::build(input_params, is_prefill);
torch::Tensor h = embed_tokens_(tokens);
for (size_t i = 0; i < layers_.size(); i++) {
aclrtEvent* event = nullptr;
std::atomic<bool>* event_flag = nullptr;
if (input_params.layer_synchronizer != nullptr) {
event = input_params.layer_synchronizer->get_event(i);
event_flag = input_params.layer_synchronizer->get_event_flag(i);
}
if (input_params.layer_wise_load_synchronizer != nullptr) {
if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
return torch::Tensor();
}
}

auto& layer = layers_[i];
layer(h,
cos_pos,
sin_pos,
attn_mask,
kv_caches[i],
input_params,
event,
event_flag);
h = layer(h, positions, attn_metadata, kv_caches[i], input_params);
}
return norm_(h, 0);
return norm_(h);
}

// Provide batched signature to satisfy callers that pass vectors
torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
return forward_native(tokens, positions, kv_caches, input_params);
}

// load the weight from the checkpoint
@@ -217,32 +143,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
}

void verify_loaded_weights(const std::string& prefix) const {
embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
for (int i = 0; i < layers_.size(); i++) {
layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
".");
}
norm_->verify_loaded_weights(prefix + "norm.");
}

void merge_loaded_weights() {
embed_tokens_->merge_loaded_weights();
for (int i = 0; i < layers_.size(); i++) {
layers_[i]->merge_loaded_weights();
}
norm_->merge_loaded_weights();
}

void prepare_expert_weight(int32_t layer_id,
const std::vector<int32_t>& expert_ids) {
layers_[layer_id]->prepare_expert_weight(expert_ids);
}

void update_expert_weight(int32_t layer_id) {
layers_[layer_id]->update_expert_weight();
}

layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

void set_word_embedding(layer::WordEmbedding& word_embedding) {
@@ -252,90 +152,21 @@
private:
torch::nn::ModuleList blocks_{nullptr};
std::vector<DeepseekV2DecoderLayer> layers_;
int32_t max_seq_len_ = 0;
int32_t dp_rank_;
int32_t rank_;
int32_t dp_size_;
int32_t dp_local_tp_size_;
nlohmann::json mapping_data_;
int32_t num_experts_per_tok_;
int32_t num_speculative_tokens_ = 0;
at::Device device_;
torch::Dtype dtype_;
layer::WordEmbedding embed_tokens_{nullptr};
std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
layer::PosEmbedding atb_pos_emb_{nullptr};
layer::AttentionMask attn_mask_;
layer::RmsNorm norm_{nullptr};
};
TORCH_MODULE(DeepseekV2Model);

class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
class DeepseekV2ForCausalLMImpl
: public LlmForCausalLMImplBase<DeepseekV2Model> {
public:
DeepseekV2ForCausalLMImpl(const ModelContext& context) {
model_ = register_module("model", DeepseekV2Model(context));
lm_head_ = register_module("lm_head", layer::LmHead(context));
first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
}

// tokens: [num_tokens]
// positions: [num_tokens] token pos in the sequence
// returns: [num_tokens, hidden_size]
torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
return model_(tokens, positions, kv_caches, input_params);
}

// hidden_states: [num_tokens, hidden_size]
// seleted_idxes: [num_tokens]
// returns: [num_tokens, vocab_size]
torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return lm_head_(hidden_states, seleted_idxes, 0);
}

void load_model(std::unique_ptr<ModelLoader> loader) {
for (const auto& state_dict : loader->get_state_dicts()) {
model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
}

// verify
model_->verify_loaded_weights("model.");
lm_head_->verify_loaded_weights("lm_head.");

model_->merge_loaded_weights();
lm_head_->merge_loaded_weights();
}

void prepare_expert_weight(int32_t layer_id,
const std::vector<int32_t>& expert_ids) {
model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
expert_ids);
}

void update_expert_weight(int32_t layer_id) {
model_->update_expert_weight(layer_id + first_k_dense_replace_);
}

layer::LmHead get_lm_head() { return lm_head_; }

void set_lm_head(layer::LmHead& head) { lm_head_ = head; }

layer::WordEmbedding get_word_embedding() {
return model_->get_word_embedding();
}

void set_word_embedding(layer::WordEmbedding& word_embedding) {
model_->set_word_embedding(word_embedding);
}

private:
DeepseekV2Model model_{nullptr};
layer::LmHead lm_head_{nullptr};
int32_t first_k_dense_replace_;
DeepseekV2ForCausalLMImpl(const ModelContext& context)
: LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
};
TORCH_MODULE(DeepseekV2ForCausalLM);
