From 18d6699b6efa6fb9916c65278fd22ac0ced8f082 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 23 Jan 2026 01:32:36 +0000 Subject: [PATCH] issue/791 fix add_rmsnorm api and rmsnorm module --- include/infinicore/nn/rmsnorm.hpp | 23 ++++- include/infinicore/ops/add_rms_norm.hpp | 14 ++-- include/infiniop/ops/add_rms_norm.h | 6 +- python/infinicore/__init__.py | 2 +- python/infinicore/ops/add_rms_norm.py | 29 ++----- src/infinicore/nn/rmsnorm.cc | 18 ++++ .../ops/add_rms_norm/add_rms_norm.cc | 24 +++--- .../ops/add_rms_norm/add_rms_norm_infiniop.cc | 71 ++++++++-------- src/infiniop/ops/add_rms_norm/add_rms_norm.h | 6 +- .../ops/add_rms_norm/cpu/add_rms_norm_cpu.cc | 30 +++---- src/infiniop/ops/add_rms_norm/info.h | 6 +- .../nvidia/add_rms_norm_nvidia.cu | 10 +-- src/infiniop/ops/add_rms_norm/operator.cc | 12 +-- test/infinicore/ops/add_rms_norm.py | 84 ++++++++++++------- test/infiniop/add_rms_norm.py | 40 +++++++-- test/infiniop/libinfiniop/op_register.py | 2 + 16 files changed, 225 insertions(+), 152 deletions(-) diff --git a/include/infinicore/nn/rmsnorm.hpp b/include/infinicore/nn/rmsnorm.hpp index 212b2a6e4..5891819eb 100644 --- a/include/infinicore/nn/rmsnorm.hpp +++ b/include/infinicore/nn/rmsnorm.hpp @@ -1,7 +1,7 @@ #pragma once -#include "module.hpp" #include "../ops.hpp" +#include "module.hpp" namespace infinicore::nn { @@ -57,6 +57,21 @@ class RMSNorm : public Module { */ Tensor forward(const Tensor &x) const; + /** + * @brief Forward pass: apply RMSNorm in-place with residual + * + * @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions. + * Will be modified in-place to the normalized output. + * @param residual Residual tensor to add to input before normalization. + * Will be modified in-place to the sum of input and residual. + * + * The normalization is applied over the last dimension. 
+ * For example: + * Input: [batch, seq_len, hidden_size] -> normalize over hidden_size + * Input: [batch, hidden_size] -> normalize over hidden_size + */ + void forward_inplace(Tensor &x, Tensor &residual) const; + // Module information size_t normalized_shape() const { return normalized_shape_; } double eps() const { return eps_; } @@ -73,9 +88,9 @@ class RMSNorm : public Module { INFINICORE_NN_PARAMETER(weight); private: - size_t normalized_shape_; // Size of the feature dimension - double eps_; // Epsilon for numerical stability - DataType dtype_; // Data type for weight + size_t normalized_shape_; // Size of the feature dimension + double eps_; // Epsilon for numerical stability + DataType dtype_; // Data type for weight }; } // namespace infinicore::nn diff --git a/include/infinicore/ops/add_rms_norm.hpp b/include/infinicore/ops/add_rms_norm.hpp index e8a955a3c..50064e0a4 100644 --- a/include/infinicore/ops/add_rms_norm.hpp +++ b/include/infinicore/ops/add_rms_norm.hpp @@ -5,16 +5,14 @@ #include namespace infinicore::op { -class AddRMSNorm { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float); - static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float); // Fused Add and RMS Normalization // Returns: (normalized_result, add_result) // The add_result can be used as residual for subsequent layers -std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); +std::pair add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f); +// Fused 
Add and RMS Normalization (inplace) +// normalized_result will be stored in input, add_result will be stored in residual +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f); } // namespace infinicore::op diff --git a/include/infiniop/ops/add_rms_norm.h b/include/infiniop/ops/add_rms_norm.h index 7742c1343..52cd096a6 100644 --- a/include/infiniop/ops/add_rms_norm.h +++ b/include/infiniop/ops/add_rms_norm.h @@ -9,11 +9,11 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc); + float epsilon); __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); @@ -21,10 +21,10 @@ __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t de void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream); __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index c6b01d5aa..52a269ce5 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -43,7 +43,7 @@ uint8, ) from infinicore.ops.add import add -from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ +from infinicore.ops.add_rms_norm import add_rms_norm from infinicore.ops.attention import attention from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul diff --git a/python/infinicore/ops/add_rms_norm.py b/python/infinicore/ops/add_rms_norm.py index 4ad347812..a5de7bd92 100644 --- 
a/python/infinicore/ops/add_rms_norm.py +++ b/python/infinicore/ops/add_rms_norm.py @@ -1,8 +1,8 @@ +import infinicore.tensor as tensor from infinicore.lib import _infinicore -from infinicore.tensor import Tensor -def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): +def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None): """ Fused Add and RMS Normalization. @@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): The add_result can be used as residual for subsequent layers. """ if out is None: - result = _infinicore.add_rms_norm( - a._underlying, b._underlying, weight._underlying, epsilon - ) - return (Tensor(result[0]), Tensor(result[1])) + out = tensor.empty(a.shape, dtype=a.dtype, device=a.device) + if residual is None: + residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device) - y, residual_out = out _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, + out._underlying, + residual._underlying, a._underlying, b._underlying, weight._underlying, epsilon, ) - return (y, residual_out) - -def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5): - """In-place Fused Add and RMS Normalization.""" - _infinicore.add_rms_norm_( - y._underlying, - residual_out._underlying, - a._underlying, - b._underlying, - weight._underlying, - epsilon, - ) + return out, residual diff --git a/src/infinicore/nn/rmsnorm.cc b/src/infinicore/nn/rmsnorm.cc index a83c3a113..107dac44a 100644 --- a/src/infinicore/nn/rmsnorm.cc +++ b/src/infinicore/nn/rmsnorm.cc @@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const { return op::rms_norm(x, weight_, static_cast(eps_)); } +void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const { + if (!residual) { + residual = x; + x = op::rms_norm(x, weight_, static_cast(eps_)); + } else { + if (device_.getType() == Device::Type::CPU + || device_.getType() == Device::Type::NVIDIA + || device_.getType() == Device::Type::ILUVATAR + || device_.getType() == 
Device::Type::METAX + || device_.getType() == Device::Type::MOORE) { + op::add_rms_norm_inplace(x, residual, weight_, static_cast(eps_)); + } else { + op::add_(residual, x, residual); + op::rms_norm_(x, residual, weight_, static_cast(eps_)); + } + } +} + std::string RMSNorm::extra_repr() const { return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast(dtype_)) + ")"; } diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc index 650ce87e6..ccba62e21 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc @@ -4,26 +4,30 @@ namespace infinicore::op { -common::OpDispatcher &AddRMSNorm::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm); -void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon); } -std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { +void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon); +} + +std::pair add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); 
add_rms_norm_(y, residual_out, a, b, weight, epsilon); return std::make_pair(y, residual_out); } -void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { - AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); +void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { + AddRMSNorm::execute(out, residual, a, b, weight, epsilon); +} + +void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) { + add_rms_norm_(input, residual, input, residual, weight, epsilon); } } // namespace infinicore::op diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc index d6540a039..53d30a2c7 100644 --- a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc @@ -1,50 +1,53 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/add_rms_norm.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::add_rms_norm_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAddRMSNormDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, residual, a, b, weight; + float epsilon; +}; -void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { +void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) { size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + 
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, AddRMSNorm, + seed, y->desc(), residual_out->desc(), + a->desc(), b->desc(), weight->desc(), epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor); - auto desc_opt = cache.get(seed); - infiniopAddRMSNormDescriptor_t desc = nullptr; + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(residual_out), + graph::GraphTensor(a), + graph::GraphTensor(b), + graph::GraphTensor(weight), + epsilon}; - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( - context::getInfiniopHandle(device), &desc, - y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return planned; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( - desc, workspace->data(), workspace_size, - y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - AddRMSNorm::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup); } // namespace infinicore::op::add_rms_norm_impl::infiniop diff --git a/src/infiniop/ops/add_rms_norm/add_rms_norm.h b/src/infiniop/ops/add_rms_norm/add_rms_norm.h 
index c5d63333d..76451e982 100644 --- a/src/infiniop/ops/add_rms_norm/add_rms_norm.h +++ b/src/infiniop/ops/add_rms_norm/add_rms_norm.h @@ -33,19 +33,19 @@ infiniopHandle_t handle, \ Descriptor **desc_ptr, \ infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t residual_out_desc, \ infiniopTensorDescriptor_t a_desc, \ infiniopTensorDescriptor_t b_desc, \ infiniopTensorDescriptor_t weight_desc, \ - float epsilon, \ - infiniopTensorDescriptor_t residual_out_desc); \ + float epsilon); \ \ infiniStatus_t calculate( \ void *workspace, size_t workspace_size, \ void *y, \ + void *residual_out, \ const void *a, \ const void *b, \ const void *weight, \ - void *residual_out, \ void *stream) const; \ }; \ } diff --git a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc index 5e7954b71..a3099c5c4 100644 --- a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc +++ b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc @@ -10,19 +10,19 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } template -infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) { +infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const T *w) { const size_t batch_size = 
info->shape[0]; const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1; const size_t dim = info->dim(); @@ -61,7 +61,7 @@ infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T } template -infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) { +infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, T *residual_out, const T *a, const T *b, const Tw *w) { static_assert(std::is_same::value || std::is_same::value, "T must be fp16_t or bf16_t"); @@ -112,32 +112,32 @@ infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void *a, const void *b, const void *weight, + void *stream) const { if (_info.atype == INFINI_DTYPE_F16) { if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (fp16_t *)residual_out, (const fp16_t *)a, (const fp16_t *)b, 
(const bf16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_BF16) { if (_info.wtype == INFINI_DTYPE_BF16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight)); } else if (_info.wtype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight)); } else if (_info.wtype == INFINI_DTYPE_F16) { - CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out)); + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (bf16_t *)residual_out, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } } else if (_info.atype == INFINI_DTYPE_F32) { - CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (float *)residual_out, (const float *)a, (const float *)b, (const float *)weight)); } else if (_info.atype == INFINI_DTYPE_F64) { - CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out)); + CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (double *)residual_out, (const double *)a, (const double *)b, (const double *)weight)); } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/add_rms_norm/info.h b/src/infiniop/ops/add_rms_norm/info.h index abe1b5059..883aed343 100644 --- 
a/src/infiniop/ops/add_rms_norm/info.h +++ b/src/infiniop/ops/add_rms_norm/info.h @@ -16,9 +16,9 @@ class AddRMSNormInfo { float epsilon; std::vector shape; std::vector y_strides; + std::vector residual_out_strides; std::vector a_strides; std::vector b_strides; - std::vector residual_out_strides; bool has_residual_out; size_t ndim() const { return shape.size(); } @@ -26,11 +26,11 @@ class AddRMSNormInfo { static utils::Result create( infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { auto atype = y_desc->dtype(); auto wtype = weight_desc->dtype(); diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu index 03601205f..8fddf5958 100644 --- a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu @@ -49,12 +49,12 @@ infiniStatus_t Descriptor::create( infiniopHandle_t handle, Descriptor **desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { - auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + float epsilon) { + auto result = AddRMSNormInfo::create(y_desc, residual_out_desc, a_desc, b_desc, weight_desc, epsilon); CHECK_RESULT(result); auto info = result.take(); @@ -122,8 +122,8 @@ infiniStatus_t launchKernel( infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) const { + void *y, void *residual_out, const void 
*a, const void *b, const void *weight, + void *stream) const { if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc index a856e5447..11c0aef99 100644 --- a/src/infiniop/ops/add_rms_norm/operator.cc +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -32,12 +32,12 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( infiniopHandle_t handle, infiniopAddRMSNormDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t residual_out_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t weight_desc, - float epsilon, - infiniopTensorDescriptor_t residual_out_desc) { + float epsilon) { #define CREATE(CASE, NAMESPACE) \ case CASE: \ @@ -45,11 +45,11 @@ __C infiniStatus_t infiniopCreateAddRMSNormDescriptor( handle, \ reinterpret_cast(desc_ptr), \ y_desc, \ + residual_out_desc, \ a_desc, \ b_desc, \ weight_desc, \ - epsilon, \ - residual_out_desc) + epsilon) switch (handle->device) { #ifdef ENABLE_CPU_API @@ -116,16 +116,16 @@ __C infiniStatus_t infiniopAddRMSNorm( void *workspace, size_t workspace_size, void *y, + void *residual_out, const void *a, const void *b, const void *weight, - void *residual_out, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ ->calculate(workspace, workspace_size, y, residual_out, a, b, weight, stream) switch (desc->device_type) { diff --git a/test/infinicore/ops/add_rms_norm.py b/test/infinicore/ops/add_rms_norm.py index 429d9df25..f6bf165a9 100644 --- a/test/infinicore/ops/add_rms_norm.py +++ b/test/infinicore/ops/add_rms_norm.py @@ -30,8 +30,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 
4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ] # Tolerance configuration @@ -87,12 +103,14 @@ def parse_test_cases(): y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result) - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + residual_out_spec = TensorSpec.from_tensor( + a_shape, a_strides, input_dtype + ) test_cases.append( TestCase( inputs=[a_spec, b_spec, w_spec], kwargs={"epsilon": _EPSILON}, - output_specs=[y_spec, residual_out_spec], # Two outputs + output_specs=None, # Two outputs comparison_target=None, tolerance=tolerance, output_count=2, # Two outputs: normalized_result and add_result @@ -101,19 +119,25 @@ def parse_test_cases(): ) # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w)) - if y_supports_inplace: - residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) - test_cases.append( - TestCase( - inputs=[a_spec, b_spec, w_spec], - kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)}, - output_specs=[y_spec, residual_out_spec], # Two outputs - comparison_target="out", - tolerance=tolerance, - output_count=2, - description=f"AddRMSNorm - INPLACE(out)", - ) - ) + # if y_supports_inplace: + # residual_out_spec = TensorSpec.from_tensor( + # a_shape, a_strides, input_dtype + # ) + # test_cases.append( + # TestCase( + # inputs=[a_spec, b_spec, w_spec], + # kwargs={ + # "epsilon": _EPSILON, + # "out": 
y_spec, + # "residual": residual_out_spec, + # }, + # output_specs=[y_spec, residual_out_spec], # Two outputs + # comparison_target="out", + # tolerance=tolerance, + # output_count=2, + # description=f"AddRMSNorm - INPLACE(out)", + # ) + # ) return test_cases @@ -127,7 +151,9 @@ def __init__(self): def get_test_cases(self): return parse_test_cases() - def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def torch_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)""" input_dtype = a.dtype @@ -144,21 +170,19 @@ def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): add_result = sum_tensor.to(input_dtype) if out is not None: - # For in-place operations, we need to handle the output tuple - if isinstance(out, (tuple, list)) and len(out) == 2: - out[0].copy_(normalized_result) - out[1].copy_(add_result) - return tuple(out) - else: - # Single output - just return normalized result for backward compatibility - out.copy_(normalized_result) - return out - + out.copy_(normalized_result) + if residual is not None: + residual.copy_(add_result) + return (normalized_result, add_result) - def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + def infinicore_operator( + self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs + ): """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)""" - return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) + return infinicore.add_rms_norm( + a, b, weight, epsilon, out=out, residual=residual + ) def main(): diff --git a/test/infiniop/add_rms_norm.py b/test/infiniop/add_rms_norm.py index 930314761..e3b4f9b64 100644 --- a/test/infiniop/add_rms_norm.py +++ b/test/infiniop/add_rms_norm.py @@ -32,8 +32,24 @@ ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), ((15, 3584), (15, 
3584), (15, 3584), (3584,), None, None, None), ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), - ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (2048, 8192, 1), + (2048, 8192, 1), + (2048, 8192, 1), + ), + ( + (4, 4, 2048), + (4, 4, 2048), + (4, 4, 2048), + (2048,), + (16384, 4096, 1), + (16384, 4096, 1), + (16384, 4096, 1), + ), ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None), ] @@ -97,7 +113,9 @@ def test( w = TestTensor(w_shape, None, w_dtype, device) eps = 1e-6 - add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps) + add_rms_norm( + y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps + ) if sync is not None: sync() @@ -109,11 +127,11 @@ def test( handle, ctypes.byref(descriptor), y.descriptor, + residual_out.descriptor, a.descriptor, b.descriptor, w.descriptor, eps, - residual_out.descriptor, ) ) @@ -136,10 +154,10 @@ def lib_add_rms_norm(): workspace.data(), workspace_size.value, y.data(), + residual_out.data(), a.data(), b.data(), w.data(), - residual_out.data(), None, ) ) @@ -147,18 +165,22 @@ def lib_add_rms_norm(): lib_add_rms_norm() atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - + # Verify normalized result (y) if DEBUG: debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) - + # Verify add result (residual_out) - should be a + b - expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32) + expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to( + torch.float32 + ) expected_residual = 
expected_residual.to(a.torch_tensor().dtype) if DEBUG: debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) - assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol) + assert torch.allclose( + residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol + ) # Profiling workflow if PROFILE: diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 618be2b05..7d6cf17e2 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -393,6 +393,7 @@ def add_rms_norm_(lib): infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, c_float, ] @@ -412,6 +413,7 @@ def add_rms_norm_(lib): c_void_p, c_void_p, c_void_p, + c_void_p, ] lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32