
Commit d1e10c7 (parent 0623eab)

refactor: redesign wrapper for NPU fused_layernorm operator.

File tree

4 files changed: +9 -12 lines


xllm/core/kernels/ops_api.cpp

Lines changed: 3 additions & 9 deletions

@@ -283,6 +283,9 @@ void fused_layernorm(FusedLayerNormParams& params) {
       params.store_output_before_norm,
       params.store_output_after_norm,
       params.dynamic_quant);
+#elif defined(USE_NPU)
+  params.output = npu::fused_layernorm(
+      params.input, params.weight, params.eps, params.mode);
 #elif defined(USE_CUDA)
   if (params.residual.has_value()) {
     cuda::fused_add_rms_norm(
@@ -306,15 +309,6 @@ void fused_layernorm(FusedLayerNormParams& params) {
 #endif
 }
 
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) {
-#if defined(USE_NPU)
-  return npu::fused_layernorm(
-      params.input, params.weight, params.eps, params.mode);
-#else
-  LOG(FATAL) << "fused_layernorm not implemented";
-#endif
-}
-
 torch::Tensor matmul(MatmulParams& params) {
 #if defined(USE_MLU)
   return mlu::matmul(
xllm/core/kernels/ops_api.h

Lines changed: 0 additions & 2 deletions

@@ -36,8 +36,6 @@ void batch_decode(AttentionParams& params);
 
 void fused_layernorm(FusedLayerNormParams& params);
 
-torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params);
-
 torch::Tensor matmul(MatmulParams& params);
 
 torch::Tensor group_gemm(GroupGemmParams& params);

xllm/core/layers/common/dense_mlp.cpp

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@ limitations under the License.
 #include <glog/logging.h>
 
 #include "kernels/ops_api.h"
+#include "platform/device.h"
 
 namespace xllm {
 namespace layer {

xllm/core/layers/common/rms_norm.cpp

Lines changed: 5 additions & 1 deletion

@@ -18,6 +18,7 @@ limitations under the License.
 #include <glog/logging.h>
 
 #include "kernels/ops_api.h"
+#include "platform/device.h"
 
 namespace xllm {
 namespace layer {
@@ -40,7 +41,10 @@ RMSNormImpl::RMSNormImpl(const ModelContext& context)
       context.get_tensor_options()) {}
 
 torch::Tensor RMSNormImpl::forward(torch::Tensor& input) {
-  auto output = torch::empty_like(input);
+  torch::Tensor output;
+  if (Device::type_str() != "npu") {
+    output = torch::empty_like(input);
+  }
   return forward_output(input, output);
 }
 
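
The Device::type_str() != "npu" guard pairs with the dispatcher change above: on NPU the kernel allocates its own result and the dispatcher assigns it to params.output, so pre-allocating a buffer with torch::empty_like would be wasted work, while the other backends still write into the caller-provided tensor. A hedged sketch of how forward_output plausibly threads the tensor through; the real body is not part of this diff, and weight_, eps_, and the kernel:: qualifier are assumptions:

// Sketch under stated assumptions, not the actual implementation.
torch::Tensor RMSNormImpl::forward_output(torch::Tensor& input,
                                          torch::Tensor& output) {
  kernel::FusedLayerNormParams params;  // assumed namespace for ops_api.h types
  params.input = input;
  params.weight = weight_;              // assumed member: RMSNorm gamma weights
  params.eps = eps_;                    // assumed member: variance epsilon
  params.output = output;               // stays undefined when running on NPU
  kernel::fused_layernorm(params);
  return params.output;                 // on NPU, the kernel's freshly allocated result
}

Passing a default-constructed (undefined) torch::Tensor is cheap: it carries no storage until a backend that needs it fills it in.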
