
Commit 672e63c

bugfix: fix bug caused by atten_mask nullptr on mlu device. (#513)
Signed-off-by: pengtao.156 <pengtao.156@jd.com>
Parent: 8a2110c

17 files changed: +98 / -93 lines

xllm/core/layers/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -21,9 +21,9 @@ cc_library(
   NAME
   attention_mask
   HDRS
-  common/attention_mask_impl.h
+  common/attention_mask.h
   SRCS
-  common/attention_mask_impl.cpp
+  common/attention_mask.cpp
   DEPS
   :state_dict
   :block

xllm/core/layers/common/attention_mask_impl.cpp renamed to xllm/core/layers/common/attention_mask.cpp

Lines changed: 22 additions & 23 deletions
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "attention_mask_impl.h"
+#include "attention_mask.h"
 
 namespace xllm {
 namespace layer {
 
-AttentionMaskImpl::AttentionMaskImpl(at::Device device,
-                                     torch::Dtype dtype,
-                                     float mask_value) {
+AttentionMask::AttentionMask(at::Device device,
+                             torch::Dtype dtype,
+                             float mask_value) {
   int max_seq_len = 128;
   seq_len_cached_ = max_seq_len;
   auto bias_cache =
@@ -37,25 +37,24 @@ AttentionMaskImpl::AttentionMaskImpl(at::Device device,
       .to(device);
 }
 
-torch::Tensor AttentionMaskImpl::get_decode_attn_mask(
-    torch::Tensor input_lengths,
-    int64_t max_s,
-    torch::Dtype dtype,
-    torch::Device device) {
+torch::Tensor AttentionMask::get_decode_attn_mask(torch::Tensor input_lengths,
+                                                  int64_t max_s,
+                                                  torch::Dtype dtype,
+                                                  torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.index_select(0, input_lengths).view({-1, 1, max_s});
 }
 
-torch::Tensor AttentionMaskImpl::get_attn_mask(int64_t max_s,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::get_attn_mask(int64_t max_s,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   update_attn_cache(dtype, device, max_s);
   return atten_mask_cache_.slice(0, 0, max_s).slice(1, 0, max_s);
 }
 
-torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
-                                               torch::Dtype dtype,
-                                               torch::Device device) {
+torch::Tensor AttentionMask::gen_free_mask(int32_t q_len,
+                                           torch::Dtype dtype,
+                                           torch::Device device) {
   float pre_mask_factor = -10000.0f;
   if (dtype == torch::kBFloat16) {
     pre_mask_factor = 1.0f;
@@ -68,11 +67,11 @@ torch::Tensor AttentionMaskImpl::gen_free_mask(int32_t q_len,
   return mask_free;
 }
 
-torch::Tensor AttentionMaskImpl::gen_append_mask(int32_t q_len,
-                                                 int32_t kv_len,
-                                                 int32_t max_kv_len,
-                                                 torch::Dtype dtype,
-                                                 torch::Device device) {
+torch::Tensor AttentionMask::gen_append_mask(int32_t q_len,
+                                             int32_t kv_len,
+                                             int32_t max_kv_len,
+                                             torch::Dtype dtype,
+                                             torch::Device device) {
   int diagonal = kv_len - q_len;
   auto options = torch::TensorOptions().dtype(torch::kBool).device(device);
   auto bias = torch::tril(torch::ones({q_len, max_kv_len}, options), diagonal);
@@ -84,9 +83,9 @@ torch::Tensor AttentionMaskImpl::gen_append_mask(int32_t q_len,
   return mask;
 }
 
-void AttentionMaskImpl::update_attn_cache(torch::Dtype dtype,
-                                          torch::Device device,
-                                          int64_t seqlen) {
+void AttentionMask::update_attn_cache(torch::Dtype dtype,
+                                      torch::Device device,
+                                      int64_t seqlen) {
   if (seqlen > seq_len_cached_ || atten_mask_cache_.dtype() != dtype) {
     seq_len_cached_ = seqlen;
 

xllm/core/layers/common/attention_mask_impl.h renamed to xllm/core/layers/common/attention_mask.h

Lines changed: 5 additions & 6 deletions
@@ -19,13 +19,13 @@ limitations under the License.
 namespace xllm {
 namespace layer {
 
-class AttentionMaskImpl : public torch::nn::Module {
+class AttentionMask : public torch::nn::Module {
  public:
-  AttentionMaskImpl() = default;
+  AttentionMask() = default;
 
-  explicit AttentionMaskImpl(at::Device device,
-                             torch::Dtype dtype,
-                             float mask_value = -9984);
+  explicit AttentionMask(at::Device device,
+                         torch::Dtype dtype,
+                         float mask_value = -9984);
 
   torch::Tensor get_decode_attn_mask(torch::Tensor input_lengths,
                                      int64_t max_s,
@@ -55,7 +55,6 @@ class AttentionMaskImpl : public torch::nn::Module {
   float mask_value_;
   at::Tensor atten_mask_cache_;
 };
-TORCH_MODULE(AttentionMask);
 
 }  // namespace layer
 }  // namespace xllm
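
Note that beyond the rename, the header above also drops TORCH_MODULE(AttentionMask): AttentionMask is now used as a plain value type rather than through a generated torch::nn::ModuleHolder, which is the heart of the nullptr fix in the commit title. A minimal sketch of the difference; the constructor shown here is an assumption, since the real initialization sites are in files outside this diff:

#include "core/layers/common/attention_mask.h"

// Before this commit the models held the mask through the generated holder:
//   layer::AttentionMask attn_mask_{nullptr};         // empty ModuleHolder
//   attn_mask_->get_attn_mask(128, dtype_, device_);  // fails: no impl behind ->
// Now the class itself is the member, so there is nothing left to be null:
class MaskOwnerSketch : public torch::nn::Module {
 public:
  // Assumed initialization; a value member is always a constructed object.
  MaskOwnerSketch(at::Device device, torch::Dtype dtype)
      : attn_mask_(device, dtype) {}

  torch::Tensor full_mask(torch::Dtype dtype, torch::Device device) {
    // '.' access, matching the call-site changes in the model diffs below.
    return attn_mask_.get_attn_mask(/*max_s=*/128, dtype, device);
  }

 private:
  xllm::layer::AttentionMask attn_mask_;
};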

xllm/core/layers/common/rotary_embedding.cpp

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ MRotaryEmbeddingImpl::MRotaryEmbeddingImpl(
           rope_theta,
           interleaved,
           options),
-      interleaved_(interleaved),
       mrope_section_(rope_scaling_mrope_section) {
   mrope_cu_seq_lens_ = torch::zeros(2, torch::kInt32).to(options.device());
 }

xllm/core/layers/common/rotary_embedding.h

Lines changed: 3 additions & 1 deletion
@@ -52,8 +52,10 @@ class RotaryEmbeddingImpl : public torch::nn::Module {
 
   torch::Tensor get_cos_sin_cache() { return cos_sin_cache_; }
 
- private:
+ protected:
   bool interleaved_;
+
+ private:
   torch::Tensor sin_;
   torch::Tensor cos_;
   torch::Tensor cos_sin_cache_;
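
This pairs with the rotary_embedding.cpp hunk above: once interleaved_ is protected in RotaryEmbeddingImpl, the derived MRotaryEmbeddingImpl no longer initializes its own interleaved_(interleaved) and can read the flag stored in the base. A tiny, hypothetical sketch of the pattern (class names are placeholders, not from the repo):

// Only the protected-member pattern mirrors this diff; the classes are made up.
class BaseEmbedding {
 public:
  explicit BaseEmbedding(bool interleaved) : interleaved_(interleaved) {}

 protected:
  bool interleaved_;  // visible to derived classes, so no shadow copy is needed
};

class DerivedEmbedding : public BaseEmbedding {
 public:
  explicit DerivedEmbedding(bool interleaved) : BaseEmbedding(interleaved) {}

  // Reads the single flag kept in the base instead of a second,
  // separately initialized copy.
  bool is_interleaved() const { return interleaved_; }
};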

xllm/core/layers/config.h

Lines changed: 12 additions & 1 deletion
@@ -112,4 +112,15 @@ REGISTER_NOT_IMPLEMENTED_CLASS(SiglipEncoderLayerImpl);
 #include "npu/npu_glm4_decoder_layer_impl.h"
 #else
 REGISTER_NOT_IMPLEMENTED_CLASS(Glm4DecoderLayerImpl);
-#endif
+#endif
+
+#if defined(USE_NPU)
+#include "npu/npu_glm4_vision_encoder_layer_impl.h"
+namespace xllm {
+namespace layer {
+using Glm4VisionEncoderLayerImpl = NpuGlm4VisionEncoderLayerImpl;
+}
+}  // namespace xllm
+#else
+REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
+#endif

xllm/core/layers/glm4_vision_encode_layer.h

Lines changed: 5 additions & 10 deletions
@@ -15,25 +15,20 @@ limitations under the License.
 
 #pragma once
 
-#if defined(USE_NPU)
-#include "npu/npu_glm4_vision_encoder_layer_impl.h"
-#endif
+#include "config.h"
 
 namespace xllm {
 namespace layer {
 
-#if defined(USE_NPU)
 class Glm4VisionEncoderLayer
-    : public torch::nn::ModuleHolder<NpuGlm4VisionEncoderLayerImpl> {
+    : public torch::nn::ModuleHolder<Glm4VisionEncoderLayerImpl> {
  public:
-  using torch::nn::ModuleHolder<NpuGlm4VisionEncoderLayerImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuGlm4VisionEncoderLayerImpl;
+  using torch::nn::ModuleHolder<Glm4VisionEncoderLayerImpl>::ModuleHolder;
+  using Impl __attribute__((__unused__)) = Glm4VisionEncoderLayerImpl;
 
   Glm4VisionEncoderLayer(const ModelContext& context)
-      : ModuleHolder(std::make_shared<NpuGlm4VisionEncoderLayerImpl>(context)) {
-      }
+      : ModuleHolder(std::make_shared<Glm4VisionEncoderLayerImpl>(context)) {}
 };
-#endif
 
 }  // namespace layer
 }  // namespace xllm
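
Read together with the config.h hunk above, the pattern is: config.h selects a backend class and exposes it under the neutral alias Glm4VisionEncoderLayerImpl (or, presumably, registers a stub via REGISTER_NOT_IMPLEMENTED_CLASS so the name still exists), and this header wraps whatever the alias names, so the wrapper no longer carries its own #if defined(USE_NPU) guards. A condensed sketch of the two pieces as they look after this commit, with unrelated includes omitted:

// config.h side: pick the backend implementation behind a common name.
#if defined(USE_NPU)
#include "npu/npu_glm4_vision_encoder_layer_impl.h"
namespace xllm {
namespace layer {
using Glm4VisionEncoderLayerImpl = NpuGlm4VisionEncoderLayerImpl;
}  // namespace layer
}  // namespace xllm
#else
// Presumed to define a placeholder type on non-NPU builds.
REGISTER_NOT_IMPLEMENTED_CLASS(Glm4VisionEncoderLayerImpl);
#endif

// glm4_vision_encode_layer.h side: backend-agnostic wrapper around the alias.
namespace xllm {
namespace layer {
class Glm4VisionEncoderLayer
    : public torch::nn::ModuleHolder<Glm4VisionEncoderLayerImpl> {
 public:
  using torch::nn::ModuleHolder<Glm4VisionEncoderLayerImpl>::ModuleHolder;
  Glm4VisionEncoderLayer(const ModelContext& context)
      : ModuleHolder(std::make_shared<Glm4VisionEncoderLayerImpl>(context)) {}
};
}  // namespace layer
}  // namespace xllm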

xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ limitations under the License.
 #include <map>
 
 #include "common/global_flags.h"
-#include "core/layers/common/attention_mask_impl.h"
+#include "core/layers/common/attention_mask.h"
 #include "loader/llama_decoder_loader.h"
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUException.h"

xllm/models/llm/deepseek_v2.h

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,7 @@ limitations under the License.
 #include "core/framework/model/model_input_params.h"
 #include "core/framework/model/npu_dp_ep_padding.h"
 #include "core/framework/model_context.h"
-#include "core/layers/common/attention_mask_impl.h"
+#include "core/layers/common/attention_mask.h"
 #include "core/layers/deepseek_v2_decoder_layer.h"
 #include "core/layers/lm_head.h"
 #include "core/layers/npu/npu_rms_norm_impl.h"
@@ -159,9 +159,9 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
 
     torch::Tensor attn_mask;
     if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_->get_attn_mask(128, dtype_, device_);
+      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
     } else {
-      attn_mask = attn_mask_->gen_free_mask(
+      attn_mask = attn_mask_.gen_free_mask(
           num_speculative_tokens_ + 1, dtype_, device_);
     }
 
@@ -251,7 +251,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
   layer::WordEmbedding embed_tokens_{nullptr};
   std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
   layer::PosEmbedding atb_pos_emb_{nullptr};
-  layer::AttentionMask attn_mask_{nullptr};
+  layer::AttentionMask attn_mask_;
   layer::RMSNorm norm_{nullptr};
 };
 TORCH_MODULE(DeepseekV2Model);
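
The -> to . change above is mechanical, but the branch it touches shows how the two mask helpers are chosen: the cached dense mask for ordinary prefill/decode, and a small free-form mask when speculative tokens are in flight. A hedged sketch of just that selection, lifted out of DeepseekV2ModelImpl; the parameters stand in for the member variables and input_params fields used in the hunk:

// Sketch only; in the real model these values live on DeepseekV2ModelImpl.
torch::Tensor select_attn_mask(xllm::layer::AttentionMask& attn_mask,
                               int num_speculative_tokens,
                               bool global_empty_kv_cache,
                               torch::Dtype dtype,
                               torch::Device device) {
  if (num_speculative_tokens == 0 || global_empty_kv_cache) {
    // Ordinary path: slice the cached mask (the cache is seeded with 128
    // positions in the AttentionMask constructor and grown on demand).
    return attn_mask.get_attn_mask(128, dtype, device);
  }
  // Speculative-decoding path: a compact mask covering the draft tokens plus one.
  return attn_mask.gen_free_mask(num_speculative_tokens + 1, dtype, device);
}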

xllm/models/llm/deepseek_v2_mtp.h

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@ limitations under the License.
 #include "core/framework/model/model_input_params.h"
 #include "core/framework/model/npu_dp_ep_padding.h"
 #include "core/framework/model_context.h"
-#include "core/layers/common/attention_mask_impl.h"
+#include "core/layers/common/attention_mask.h"
 #include "core/layers/deepseek_v2_decoder_layer.h"
 #include "core/layers/lm_head.h"
 #include "core/layers/npu/npu_column_parallel_linear_impl.h"
@@ -122,7 +122,7 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
     auto cos_pos = cos_sin_chunks[0].contiguous();
     auto sin_pos = cos_sin_chunks[1].contiguous();
 
-    auto attn_mask = attn_mask_->get_attn_mask(
+    auto attn_mask = attn_mask_.get_attn_mask(
         128, cos_pos.dtype().toScalarType(), cos_pos.device());
     for (size_t i = 0; i < layers_.size(); i++) {
       aclrtEvent* event = nullptr;
@@ -205,7 +205,7 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
   layer::WordEmbedding embed_tokens_{nullptr};
   std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
   layer::PosEmbedding atb_pos_emb_{nullptr};
-  layer::AttentionMask attn_mask_{nullptr};
+  layer::AttentionMask attn_mask_;
   layer::ColumnParallelLinear eh_proj_{nullptr};
   layer::RMSNorm enorm_{nullptr};
   layer::RMSNorm hnorm_{nullptr};
