1 change: 1 addition & 0 deletions include/infinicore/nn.hpp
@@ -1,5 +1,6 @@
#pragma once

#include "nn/embedding.hpp"
#include "nn/layernorm.hpp"
#include "nn/linear.hpp"
#include "nn/rmsnorm.hpp"
60 changes: 60 additions & 0 deletions include/infinicore/nn/layernorm.hpp
@@ -0,0 +1,60 @@
#pragma once

#include "module.hpp"
#include "../ops.hpp"

namespace infinicore::nn {

/**
* @brief Layer Normalization
*
* Applies LayerNorm over the last dimension.
*
* Formula: y = (x - mean) / sqrt(var + eps) * weight + bias
*/
class LayerNorm : public Module {
public:
/**
* @brief Construct a LayerNorm layer
*
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
* @param eps Small constant for numerical stability (default: 1e-5)
* @param dtype Data type for the weight/bias (default: DataType::F32)
* @param device Device to create the parameters on
*/
LayerNorm(size_t normalized_shape,
double eps = 1e-5,
const DataType &dtype = DataType::F32,
const Device &device = Device());

/**
* @brief Forward pass: apply LayerNorm
*
* @param x Input tensor of shape (*, normalized_shape)
* @return Normalized tensor with same shape as input
*/
Tensor forward(const Tensor &x) const;

// Module information
size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; }
DataType dtype() const { return dtype_; }

// String representation
std::string extra_repr() const;

// Accessors for parameters
Tensor weight() const { return weight_; }
Tensor bias() const { return bias_; }

protected:
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);

private:
size_t normalized_shape_;
double eps_;
DataType dtype_;
};

} // namespace infinicore::nn
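A minimal usage sketch for the new module; the <infinicore/nn.hpp> include path, the use of Tensor::ones to build an input, and the concrete shapes are illustrative assumptions rather than part of this change:

#include <infinicore/nn.hpp>

using namespace infinicore;

int main() {
    // Normalize 768-wide feature vectors; eps, dtype, and device use the defaults.
    nn::LayerNorm norm(/*normalized_shape=*/768);

    // Input of shape (*, normalized_shape); here a batch of two rows of ones.
    Tensor x = Tensor::ones({2, 768}, DataType::F32, Device());
    Tensor y = norm.forward(x); // same shape as x, normalized over the last dimension
    return 0;
}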
8 changes: 8 additions & 0 deletions include/infinicore/ops.hpp
@@ -4,14 +4,22 @@
#include "ops/add_rms_norm.hpp"
#include "ops/attention.hpp"
#include "ops/causal_softmax.hpp"
#include "ops/conv2d.hpp"
#include "ops/gelu.hpp"
#include "ops/gelutanh.hpp"
#include "ops/layer_norm.hpp"
#include "ops/linear.hpp"
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/relu.hpp"
#include "ops/quickgelu.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/softmax.hpp"
#include "ops/swiglu.hpp"
38 changes: 38 additions & 0 deletions include/infinicore/ops/conv2d.hpp
@@ -0,0 +1,38 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

#include <cstddef>
#include <vector>

namespace infinicore::op {
class Conv2d {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor,
const size_t *, const size_t *, const size_t *, size_t);
static void execute(Tensor output,
Tensor input,
Tensor weight,
Tensor bias,
const size_t *pads,
const size_t *strides,
const size_t *dilations,
size_t n);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor conv2d(Tensor input,
Tensor weight,
Tensor bias,
const std::vector<size_t> &pads,
const std::vector<size_t> &strides,
const std::vector<size_t> &dilations);
void conv2d_(Tensor output,
Tensor input,
Tensor weight,
Tensor bias,
const std::vector<size_t> &pads,
const std::vector<size_t> &strides,
const std::vector<size_t> &dilations);
} // namespace infinicore::op
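A hedged sketch of the out-of-place overload; the NCHW input / OIHW weight layouts, the Tensor::ones and Tensor::zeros factories, and the concrete shapes are assumptions for illustration:

using namespace infinicore;

// 1x3x224x224 input convolved with 64 filters of size 7x7, padding 3, stride 2, dilation 1.
Tensor input  = Tensor::ones({1, 3, 224, 224}, DataType::F32, Device());
Tensor weight = Tensor::ones({64, 3, 7, 7},    DataType::F32, Device());
Tensor bias   = Tensor::zeros({64},            DataType::F32, Device());

Tensor out = op::conv2d(input, weight, bias,
                        /*pads=*/{3, 3}, /*strides=*/{2, 2}, /*dilations=*/{1, 1});
// conv2d_ takes the same arguments but writes into a caller-allocated output tensor.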
16 changes: 16 additions & 0 deletions include/infinicore/ops/gelu.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Gelu {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor gelu(Tensor input);
void gelu_(Tensor output, Tensor input);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/gelutanh.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class GeluTanh {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor gelu_tanh(Tensor input);
void gelu_tanh_(Tensor output, Tensor input);
} // namespace infinicore::op
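A short sketch of the functional interface; element-wise this is the tanh approximation documented in include/infiniop/ops/gelutanh.h further down. The input shape and the Tensor::ones factory are illustrative assumptions:

using namespace infinicore;

// y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), applied element-wise.
Tensor x = Tensor::ones({2, 768}, DataType::F32, Device());
Tensor y = op::gelu_tanh(x);   // allocates and returns the output
// op::gelu_tanh_(y, x);       // variant that writes into a pre-allocated output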
28 changes: 28 additions & 0 deletions include/infinicore/ops/layer_norm.hpp
@@ -0,0 +1,28 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class LayerNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, float);
static void execute(Tensor output,
Tensor input_standardization,
Tensor input_std_deviation,
Tensor input,
Tensor weight,
Tensor bias,
float epsilon);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor layer_norm(Tensor input, Tensor weight, Tensor bias, float epsilon = 1e-5f);
void layer_norm_(Tensor output,
Tensor input_standardization,
Tensor input_std_deviation,
Tensor input,
Tensor weight,
Tensor bias,
float epsilon = 1e-5f);
} // namespace infinicore::op
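A small sketch of the functional form; this is the call that nn::LayerNorm::forward delegates to (see src/infinicore/nn/layernorm.cc below). The shapes and Tensor factories are assumptions:

using namespace infinicore;

Tensor x      = Tensor::ones({2, 768}, DataType::F32, Device());
Tensor weight = Tensor::ones({768},    DataType::F32, Device());
Tensor bias   = Tensor::zeros({768},   DataType::F32, Device());

// y = (x - mean) / sqrt(var + eps) * weight + bias, over the last dimension.
Tensor y = op::layer_norm(x, weight, bias, /*epsilon=*/1e-5f);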
16 changes: 16 additions & 0 deletions include/infinicore/ops/quickgelu.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class QuickGelu {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor quick_gelu(Tensor input);
void quick_gelu_(Tensor output, Tensor input);
} // namespace infinicore::op
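Sketch of the functional interface; element-wise it is the sigmoid approximation from include/infiniop/ops/quickgelu.h below, y = x * sigmoid(1.702 * x). Shape and Tensor::ones are assumptions:

using namespace infinicore;

Tensor x = Tensor::ones({2, 768}, DataType::F32, Device());
Tensor y = op::quick_gelu(x);   // allocates and returns the output
// op::quick_gelu_(y, x);       // writes into a pre-allocated output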
16 changes: 16 additions & 0 deletions include/infinicore/ops/relu.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Relu {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor relu(Tensor input);
void relu_(Tensor output, Tensor input);
} // namespace infinicore::op
16 changes: 16 additions & 0 deletions include/infinicore/ops/softmax.hpp
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Softmax {
public:
using schema = void (*)(Tensor, Tensor, int);
static void execute(Tensor output, Tensor input, int axis);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor softmax(Tensor input, int axis = -1);
void softmax_(Tensor output, Tensor input, int axis = -1);
} // namespace infinicore::op
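A brief sketch of the softmax entry points; the logits shape and the Tensor::ones factory are illustrative assumptions:

using namespace infinicore;

Tensor logits = Tensor::ones({4, 10}, DataType::F32, Device());
Tensor probs  = op::softmax(logits);          // defaults to axis = -1 (last dimension)
// op::softmax_(probs, logits, /*axis=*/-1);  // writes into a pre-allocated output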
2 changes: 2 additions & 0 deletions include/infiniop.h
@@ -10,6 +10,7 @@
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/gelu.h"
#include "infiniop/ops/gelutanh.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/layer_norm.h"
#include "infiniop/ops/logsoftmax.h"
@@ -20,6 +21,7 @@
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/quickgelu.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
43 changes: 43 additions & 0 deletions include/infiniop/ops/gelutanh.h
@@ -0,0 +1,43 @@
#ifndef __INFINIOP_GELUTANH_API_H__
#define __INFINIOP_GELUTANH_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopGeluTanhDescriptor_t;

/**
* Create GELU-Tanh descriptor
*
* y = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
*/
__C __export infiniStatus_t infiniopCreateGeluTanhDescriptor(
infiniopHandle_t handle,
infiniopGeluTanhDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);

/**
* Query workspace size
*/
__C __export infiniStatus_t infiniopGetGeluTanhWorkspaceSize(
infiniopGeluTanhDescriptor_t desc,
size_t *size);

/**
* Launch GELU-Tanh operator
*/
__C __export infiniStatus_t infiniopGeluTanh(
infiniopGeluTanhDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/**
* Destroy descriptor
*/
__C __export infiniStatus_t infiniopDestroyGeluTanhDescriptor(
infiniopGeluTanhDescriptor_t desc);

#endif
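A hedged sketch of the C-level descriptor lifecycle for the new GELU-Tanh operator; the same create / workspace-query / launch / destroy pattern applies to QuickGELU below. The handle, tensor descriptors, device buffers, and stream are assumed to have been created elsewhere, and error checking is omitted:

infiniopGeluTanhDescriptor_t desc = nullptr;
infiniopCreateGeluTanhDescriptor(handle, &desc, y_desc, x_desc);

size_t workspace_size = 0;
infiniopGetGeluTanhWorkspaceSize(desc, &workspace_size);
// ... allocate `workspace` of at least workspace_size bytes on the device ...

infiniopGeluTanh(desc, workspace, workspace_size, y_data, x_data, stream);
infiniopDestroyGeluTanhDescriptor(desc);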
42 changes: 42 additions & 0 deletions include/infiniop/ops/quickgelu.h
@@ -0,0 +1,42 @@
#ifndef __INFINIOP_QUICKGELU_API_H__
#define __INFINIOP_QUICKGELU_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopQuickGeluDescriptor_t;

/**
* Create QuickGELU descriptor
* y = x * sigmoid(1.702 * x)
*/
__C __export infiniStatus_t infiniopCreateQuickGeluDescriptor(
infiniopHandle_t handle,
infiniopQuickGeluDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x);

/**
* Query workspace size
*/
__C __export infiniStatus_t infiniopGetQuickGeluWorkspaceSize(
infiniopQuickGeluDescriptor_t desc,
size_t *size);

/**
* Launch QuickGELU operator
*/
__C __export infiniStatus_t infiniopQuickGelu(
infiniopQuickGeluDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);

/**
* Destroy descriptor
*/
__C __export infiniStatus_t infiniopDestroyQuickGeluDescriptor(
infiniopQuickGeluDescriptor_t desc);

#endif
30 changes: 30 additions & 0 deletions src/infinicore/nn/layernorm.cc
@@ -0,0 +1,30 @@
#include "infinicore/nn/layernorm.hpp"

namespace infinicore::nn {

LayerNorm::LayerNorm(size_t normalized_shape,
double eps,
const DataType &dtype,
const Device &device)
: normalized_shape_(normalized_shape),
eps_(eps),
dtype_(dtype) {
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape_}, dtype_, device));
INFINICORE_NN_PARAMETER_INIT(bias, ({normalized_shape_}, dtype_, device));
auto weight_init = infinicore::Tensor::ones({normalized_shape_}, dtype_, device);
auto bias_init = infinicore::Tensor::zeros({normalized_shape_}, dtype_, device);
weight_->copy_from(weight_init);
bias_->copy_from(bias_init);
}

Tensor LayerNorm::forward(const Tensor &x) const {
return infinicore::op::layer_norm(x, weight_, bias_, static_cast<float>(eps_));
}

std::string LayerNorm::extra_repr() const {
return "normalized_shape=" + std::to_string(normalized_shape_) +
", eps=" + std::to_string(eps_) +
", dtype=" + infinicore::toString(dtype_);
}

} // namespace infinicore::nn