Commit efb36d9

feat: integrate add_rms_norm interface for NPU backend.
1 parent 6964967 commit efb36d9

5 files changed: +26 -5 lines changed


xllm/core/kernels/npu/attention.cpp

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& seq_len,
                    float scale,
                    torch::Tensor& output) {
-  auto num_heads = query.size(-2);
-  auto num_kv_heads = key.size(-2);
+  int64_t num_heads = query.size(-2);
+  int64_t num_kv_heads = key.size(-2);
   atb::_npu_flash_attention(
       query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
 }

xllm/core/kernels/npu/fused_layernorm.cpp

Lines changed: 8 additions & 0 deletions
@@ -33,4 +33,12 @@ torch::Tensor rms_norm(const torch::Tensor& input,
   return normalized_input;
 }
 
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon) {
+  return at_npu::native::custom_ops::npu_add_rms_norm(x1, x2, gamma, epsilon);
+}
+
 } // namespace xllm::kernel::npu
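
A minimal usage sketch of the new wrapper (not part of this commit): judging from how ops_api.cpp consumes the result below, the first tuple element is the normalized output and the third is the residual sum, while the middle element (assumed to be the per-row rstd) is discarded. The include path and variable names here are illustrative only.

// Sketch only; the include path and the meaning of the middle return value
// (assumed: per-row rstd) are assumptions, not taken from this commit.
#include <torch/torch.h>

#include "xllm/core/kernels/npu/npu_ops_api.h"

void rms_norm_with_residual_example(const torch::Tensor& hidden,    // e.g. [num_tokens, hidden_size]
                                    const torch::Tensor& residual,  // same shape as hidden
                                    const torch::Tensor& gamma,     // [hidden_size] RMSNorm weight
                                    double eps) {
  // Returns {normalized output, (assumed) rstd, hidden + residual}.
  auto [out, rstd, added] =
      xllm::kernel::npu::add_rms_norm(hidden, residual, gamma, eps);
  (void)rstd;  // ops_api.cpp drops this element via std::ignore
  // `out` feeds the next layer; `added` is carried forward as the running residual.
}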

xllm/core/kernels/npu/npu_ops_api.h

Lines changed: 7 additions & 0 deletions
@@ -17,6 +17,7 @@ limitations under the License.
 #include <torch/torch.h>
 
 #include <optional>
+#include <tuple>
 
 #include "custom_functions_npu/atb_common.h"
 
@@ -55,6 +56,12 @@ torch::Tensor rms_norm(const torch::Tensor& input,
                        double eps,
                        const std::string& mode);
 
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon);
+
 void apply_rotary(torch::Tensor& q,
                   torch::Tensor& k,
                   const torch::Tensor& cos_sin_cache,

xllm/core/kernels/npu/rope.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ void apply_rotary(torch::Tensor& q,
                   const torch::Tensor& cos_sin_cache,
                   const torch::Tensor& positions) {
   auto cos_sin = cos_sin_cache.index_select(0, positions);
-  auto last_dim = cos_sin.size(-1);
+  int64_t last_dim = cos_sin.size(-1);
   auto cos_sin_vec = cos_sin.view({-1, 2, last_dim / 2})
                          .repeat({1, 1, 2})
                          .chunk(2, /*dim=*/-2);

xllm/core/kernels/ops_api.cpp

Lines changed: 8 additions & 2 deletions
@@ -284,8 +284,14 @@ void fused_layernorm(FusedLayerNormParams& params) {
       params.store_output_after_norm,
       params.dynamic_quant);
 #elif defined(USE_NPU)
-  params.output =
-      npu::rms_norm(params.input, params.weight, params.eps, params.mode);
+  if (params.residual.has_value()) {
+    std::tie(params.output, std::ignore, params.residual_out) =
+        npu::add_rms_norm(
+            params.input, params.residual.value(), params.weight, params.eps);
+  } else {
+    params.output =
+        npu::rms_norm(params.input, params.weight, params.eps, params.mode);
+  }
 #elif defined(USE_CUDA)
   if (params.residual.has_value()) {
     cuda::fused_add_rms_norm(
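
With this change the USE_NPU branch mirrors the USE_CUDA branch: when params.residual is set, the fused add + RMSNorm path fills both params.output and params.residual_out in a single call. A minimal reference sketch of the semantics assumed from npu_add_rms_norm (add the residual first, then apply RMSNorm; dtype promotion is ignored here):

// Reference sketch (assumption): the fused NPU op is taken to be numerically
// equivalent to "add the residual, then RMSNorm", returning
// {normalized output, rstd, residual sum}. Not this repository's code.
#include <torch/torch.h>

#include <tuple>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm_reference(
    const torch::Tensor& x1,     // hidden states
    const torch::Tensor& x2,     // residual to be added
    const torch::Tensor& gamma,  // RMSNorm weight, shape [hidden_size]
    double epsilon) {
  torch::Tensor added = x1 + x2;  // 3rd element: becomes params.residual_out
  torch::Tensor rstd =
      torch::rsqrt(added.pow(2).mean({-1}, /*keepdim=*/true) + epsilon);
  torch::Tensor out = added * rstd * gamma;  // 1st element: becomes params.output
  return {out, rstd, added};
}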

0 commit comments