Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions include/infinicore/nn/rmsnorm.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include "module.hpp"
#include "../ops.hpp"
#include "module.hpp"

namespace infinicore::nn {

Expand Down Expand Up @@ -57,6 +57,21 @@ class RMSNorm : public Module {
*/
Tensor forward(const Tensor &x) const;

/**
* @brief Forward pass: apply RMSNorm in-place with residual
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions.
* Will be modified in-place to the normalized output.
* @param residual Residual tensor to add to input before normalization.
* Will be modified in-place to the sum of input and residual.
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*/
void forward_inplace(Tensor &x, Tensor &residual) const;

// Module information
size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; }
Expand All @@ -73,9 +88,9 @@ class RMSNorm : public Module {
INFINICORE_NN_PARAMETER(weight);

private:
size_t normalized_shape_; // Size of the feature dimension
double eps_; // Epsilon for numerical stability
DataType dtype_; // Data type for weight
size_t normalized_shape_; // Size of the feature dimension
double eps_; // Epsilon for numerical stability
DataType dtype_; // Data type for weight
};

} // namespace infinicore::nn
14 changes: 6 additions & 8 deletions include/infinicore/ops/add_rms_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
#include <utility>

namespace infinicore::op {
class AddRMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(AddRMSNorm, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &, float);

// Fused Add and RMS Normalization
// Returns: (normalized_result, add_result)
// The add_result can be used as residual for subsequent layers
std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f);
std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon = 1e-5f);
// Fused Add and RMS Normalization (inplace)
// normalized_result wil be stored in input, add_result will be stored in residual
void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon = 1e-5f);
} // namespace infinicore::op
6 changes: 3 additions & 3 deletions include/infiniop/ops/add_rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor(
infiniopHandle_t handle,
infiniopAddRMSNormDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t residual_out_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc,
infiniopTensorDescriptor_t weight_desc,
float epsilon,
infiniopTensorDescriptor_t residual_out_desc);
float epsilon);

__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
void *residual_out,
const void *a,
const void *b,
const void *weight,
void *residual_out,
void *stream);

__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc);
Expand Down
2 changes: 1 addition & 1 deletion python/infinicore/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
uint8,
)
from infinicore.ops.add import add
from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_
from infinicore.ops.add_rms_norm import add_rms_norm
from infinicore.ops.attention import attention
from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul
Expand Down
29 changes: 8 additions & 21 deletions python/infinicore/ops/add_rms_norm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import infinicore.tensor as tensor
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None, residual=None):
"""
Fused Add and RMS Normalization.

Expand All @@ -18,30 +18,17 @@ def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None):
The add_result can be used as residual for subsequent layers.
"""
if out is None:
result = _infinicore.add_rms_norm(
a._underlying, b._underlying, weight._underlying, epsilon
)
return (Tensor(result[0]), Tensor(result[1]))
out = tensor.empty(a.shape, dtype=a.dtype, device=a.device)
if residual is None:
residual = tensor.empty(b.shape, dtype=b.dtype, device=b.device)

y, residual_out = out
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
out._underlying,
residual._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
return (y, residual_out)


def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5):
"""In-place Fused Add and RMS Normalization."""
_infinicore.add_rms_norm_(
y._underlying,
residual_out._underlying,
a._underlying,
b._underlying,
weight._underlying,
epsilon,
)
return out, residual
18 changes: 18 additions & 0 deletions src/infinicore/nn/rmsnorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,24 @@ Tensor RMSNorm::forward(const Tensor &x) const {
return op::rms_norm(x, weight_, static_cast<float>(eps_));
}

void RMSNorm::forward_inplace(Tensor &x, Tensor &residual) const {
if (!residual) {
residual = x;
x = op::rms_norm(x, weight_, static_cast<float>(eps_));
} else {
if (device_.getType() == Device::Type::CPU
|| device_.getType() == Device::Type::NVIDIA
|| device_.getType() == Device::Type::ILUVATAR
|| device_.getType() == Device::Type::METAX
|| device_.getType() == Device::Type::MOORE) {
op::add_rms_norm_inplace(x, residual, weight_, static_cast<float>(eps_));
} else {
op::add_(residual, x, residual);
op::rms_norm_(x, residual, weight_, static_cast<float>(eps_));
}
}
}

std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
Expand Down
24 changes: 14 additions & 10 deletions src/infinicore/ops/add_rms_norm/add_rms_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@

namespace infinicore::op {

common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() {
static common::OpDispatcher<AddRMSNorm::schema> dispatcher_;
return dispatcher_;
};
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(AddRMSNorm);

void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
AddRMSNorm::AddRMSNorm(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight);
infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon);
INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, residual_out, a, b, weight, epsilon);
}

std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) {
void AddRMSNorm::execute(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(AddRMSNorm, y, residual_out, a, b, weight, epsilon);
}

std::pair<Tensor, Tensor> add_rms_norm(const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
auto y = Tensor::empty(a->shape(), a->dtype(), a->device());
auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device());
add_rms_norm_(y, residual_out, a, b, weight, epsilon);
return std::make_pair(y, residual_out);
}

void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon);
void add_rms_norm_(Tensor out, Tensor residual, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
AddRMSNorm::execute(out, residual, a, b, weight, epsilon);
}

void add_rms_norm_inplace(Tensor input, Tensor residual, const Tensor &weight, float epsilon) {
add_rms_norm_(input, residual, input, residual, weight, epsilon);
}

} // namespace infinicore::op
71 changes: 37 additions & 34 deletions src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc
Original file line number Diff line number Diff line change
@@ -1,50 +1,53 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/add_rms_norm.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

#include "../infiniop_impl.hpp"

namespace infinicore::op::add_rms_norm_impl::infiniop {

thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches(
100, // capacity
[](infiniopAddRMSNormDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc));
desc = nullptr;
}
});
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, AddRMSNorm, 100);

struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, out, residual, a, b, weight;
float epsilon;
};

void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) {
void *plan(Tensor y, Tensor residual_out, const Tensor &a, const Tensor &b, const Tensor &weight, float epsilon) {
size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon);

auto device = context::getDevice();
auto &cache = caches.getCache(device);
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, AddRMSNorm,
seed, y->desc(), residual_out->desc(),
a->desc(), b->desc(), weight->desc(), epsilon);

INFINIOP_WORKSPACE_TENSOR(workspace, AddRMSNorm, descriptor);

auto desc_opt = cache.get(seed);
infiniopAddRMSNormDescriptor_t desc = nullptr;
auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(y),
graph::GraphTensor(residual_out),
graph::GraphTensor(a),
graph::GraphTensor(b),
graph::GraphTensor(weight),
epsilon};

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor(
context::getInfiniopHandle(device), &desc,
y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
return planned;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

INFINICORE_CHECK_ERROR(infiniopAddRMSNorm(
desc, workspace->data(), workspace_size,
y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream()));
planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(),
planned->out->data(), planned->residual->data(), planned->a->data(), planned->b->data(), planned->weight->data(), context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}

static bool registered = []() {
AddRMSNorm::dispatcher().registerAll(&calculate, false);
return true;
}();
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(AddRMSNorm, &plan, &run, &cleanup);

} // namespace infinicore::op::add_rms_norm_impl::infiniop
6 changes: 3 additions & 3 deletions src/infiniop/ops/add_rms_norm/add_rms_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t residual_out_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t weight_desc, \
float epsilon, \
infiniopTensorDescriptor_t residual_out_desc); \
float epsilon); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
void *residual_out, \
const void *a, \
const void *b, \
const void *weight, \
void *residual_out, \
void *stream) const; \
}; \
}
Expand Down
Loading