1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
@@ -15,3 +15,4 @@
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/swiglu.hpp"
#include "ops/silu_and_mul.hpp"
18 changes: 18 additions & 0 deletions include/infinicore/ops/silu_and_mul.hpp
@@ -0,0 +1,18 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

// This macro defines the SiluAndMul class automatically, including:
// execute, dispatcher, plan_dispatcher, run_dispatcher, cleanup_dispatcher,
// as well as the corresponding schema type definitions.
INFINICORE_GRAPH_OP_CLASS(SiluAndMul, Tensor, Tensor);

// Free helper functions
Tensor silu_and_mul(Tensor x);
void silu_and_mul_(Tensor out, Tensor x);

} // namespace infinicore::op
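
For reference, a minimal usage sketch of the two entry points declared above (not part of the diff). The helper names here are illustrative only; the Tensor API calls (shape(), ndim(), dtype(), device(), Tensor::empty) are taken from the implementation file later in this PR:

#include "infinicore/ops/silu_and_mul.hpp"

using namespace infinicore;

// Out-of-place: the op infers the [..., d] output shape and allocates it internally.
Tensor fused_activation(Tensor x) {
    return op::silu_and_mul(x);
}

// Destination-passing: reuse a preallocated output tensor across calls.
Tensor fused_activation_into(Tensor x) {
    Shape out_shape = x->shape();
    out_shape[x->ndim() - 1] /= 2;
    auto out = Tensor::empty(out_shape, x->dtype(), x->device());
    op::silu_and_mul_(out, x);
    return out;
}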
1 change: 1 addition & 0 deletions include/infiniop.h
@@ -34,6 +34,7 @@
#include "infiniop/ops/topkrouter.h"
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/ops/silu_and_mul.h"
#include "infiniop/tensor_descriptor.h"

#endif // __INFINIOP_API_H__
49 changes: 49 additions & 0 deletions include/infiniop/ops/silu_and_mul.h
@@ -0,0 +1,49 @@
#ifndef __INFINIOP_SILU_AND_MUL_API_H__
#define __INFINIOP_SILU_AND_MUL_API_H__

#include "../operator_descriptor.h"

// Descriptor type definition
typedef struct InfiniopDescriptor *infiniopSiluAndMulDescriptor_t;

/**
 * @brief Create a SiluAndMul operator descriptor.
 *
 * Formula: output = silu(input_front) * input_back,
 * where input has shape [..., 2*d] and output has shape [..., d].
 */
__C __export infiniStatus_t infiniopCreateSiluAndMulDescriptor(
infiniopHandle_t handle,
infiniopSiluAndMulDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input);

/**
 * @brief Query the workspace size required to run the operator.
 */
__C __export infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(
infiniopSiluAndMulDescriptor_t desc,
size_t *size);

/**
 * @brief Run the SiluAndMul computation.
 *
 * @param workspace pointer to scratch memory
 * @param workspace_size size of the scratch memory
 * @param output output tensor data pointer, shape [..., d]
 * @param input input tensor data pointer, shape [..., 2*d]
 * @param stream hardware stream pointer (e.g. musaStream_t)
 */
__C __export infiniStatus_t infiniopSiluAndMul(
infiniopSiluAndMulDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *output,
const void *input,
void *stream);

/**
 * @brief Destroy the descriptor and release its resources.
 */
__C __export infiniStatus_t infiniopDestroySiluAndMulDescriptor(
infiniopSiluAndMulDescriptor_t desc);

#endif // __INFINIOP_SILU_AND_MUL_API_H__
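
A call-sequence sketch for the four functions above, with error handling elided. The handle, tensor descriptors, device buffers, and stream are assumed to already exist, and workspace_buf is assumed to be at least as large as the queried size:

#include <infiniop.h>

void run_silu_and_mul(infiniopHandle_t handle,
                      infiniopTensorDescriptor_t out_desc,
                      infiniopTensorDescriptor_t in_desc,
                      void *out_data, const void *in_data,
                      void *workspace_buf, void *stream) {
    // 1. Build a descriptor from the tensor shapes/dtypes.
    infiniopSiluAndMulDescriptor_t desc;
    infiniopCreateSiluAndMulDescriptor(handle, &desc, out_desc, in_desc);

    // 2. Ask how much scratch memory the selected backend needs.
    size_t workspace_size = 0;
    infiniopGetSiluAndMulWorkspaceSize(desc, &workspace_size);

    // 3. Launch the computation on the given stream.
    infiniopSiluAndMul(desc, workspace_buf, workspace_size,
                       out_data, in_data, stream);

    // 4. Release the descriptor once it is no longer needed.
    infiniopDestroySiluAndMulDescriptor(desc);
}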
1 change: 1 addition & 0 deletions python/infinicore/__init__.py
@@ -3,6 +3,7 @@
import infinicore.context as context
import infinicore.nn as nn


# Import context functions
from infinicore.context import (
get_device,
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -6,6 +6,7 @@
from .rope import RopeAlgo, rope
from .silu import silu
from .swiglu import swiglu
from .silu_and_mul import silu_and_mul

__all__ = [
"causal_softmax",
@@ -17,4 +18,5 @@
"embedding",
"rope",
"RopeAlgo",
"silu_and_mul",
]
19 changes: 19 additions & 0 deletions python/infinicore/nn/functional/silu_and_mul.py
@@ -0,0 +1,19 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def silu_and_mul(input: Tensor, out=None) -> Tensor:
r"""Apply the SiLU and Mul (SwiGLU) function.

Formula: output = SiLU(input_gate) * input_up
Input shape: [..., 2*d], Output shape: [..., d]
"""

if out is None:
# Call the C++ out-of-place interface; it creates the output tensor internally.
return Tensor(_infinicore.silu_and_mul(input._underlying))

# Call the C++ in-place / destination-passing interface.
_infinicore.silu_and_mul_(out._underlying, input._underlying)

return out
43 changes: 43 additions & 0 deletions src/infinicore/ops/silu_and_mul/silu_and_mul.cc
@@ -0,0 +1,43 @@
#include "infinicore/ops/silu_and_mul.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Dispatcher implementations.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SiluAndMul);

// Constructor: validate devices and dispatch.
SiluAndMul::SiluAndMul(Tensor out, Tensor x) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x);
// Route to the concrete implementation for the device type (e.g. Moore, CUDA).
INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x);
}

// Execute entry point: record in graph mode, or run immediately in eager mode.
void SiluAndMul::execute(Tensor out, Tensor x) {
INFINICORE_GRAPH_OP_RECORD_OR_RUN(SiluAndMul, out, x);
}

// Out-of-place interface: infers the output shape and allocates memory.
Tensor silu_and_mul(Tensor x) {
Shape shape = x->shape();
size_t ndim = x->ndim();

// SwiGLU logic: the output's last dimension is half the input's.
if (shape[ndim - 1] % 2 != 0) {
throw std::runtime_error("SiluAndMul input last dim must be even.");
}
shape[ndim - 1] /= 2;

// Create the output tensor.
auto out = Tensor::empty(shape, x->dtype(), x->device());
silu_and_mul_(out, x);
return out;
}

// In-place / destination-passing interface.
void silu_and_mul_(Tensor out, Tensor x) {
SiluAndMul::execute(out, x);
}

} // namespace infinicore::op
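
A scalar reference for a single row, handy as a correctness oracle in tests. It follows the gate-first convention documented in the C header ("silu(input_front) * input_back"); individual backends may lay out the halves differently:

#include <cmath>
#include <cstddef>
#include <vector>

// out[i] = silu(x[i]) * x[i + d] for one row of length 2*d.
std::vector<float> silu_and_mul_ref(const std::vector<float> &x) {
    const size_t d = x.size() / 2; // x.size() must be even
    std::vector<float> out(d);
    for (size_t i = 0; i < d; ++i) {
        const float gate = x[i];
        const float silu = gate / (1.0f + std::exp(-gate)); // silu(g) = g * sigmoid(g)
        out[i] = silu * x[i + d];
    }
    return out;
}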
60 changes: 60 additions & 0 deletions src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc
@@ -0,0 +1,60 @@
#include "../infiniop_impl.hpp"
#include "infinicore/ops/silu_and_mul.hpp"

namespace infinicore::op::silu_and_mul_impl::infiniop {

// Define a cacheable descriptor so infiniopDescriptor objects are not created/destroyed on every call.
INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SiluAndMul, 100);

// Metadata needed for graph-execution mode.
struct PlannedMeta {
std::shared_ptr<Descriptor> descriptor;
graph::GraphTensor workspace, output, input;
};

// Planning phase: create the descriptor and bind the tensors.
void *plan(Tensor output, Tensor input) {
// Derive a unique hash seed from the tensor descriptors (shape, dtype, etc.).
size_t seed = hash_combine(output, input);

// Fetch a cached descriptor, or create a new one.
INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
Descriptor, descriptor, SiluAndMul,
seed, output->desc(), input->desc());

// Allocate the workspace tensor (if one is needed, as reported by descriptor->workspace_size).
INFINIOP_WORKSPACE_TENSOR(workspace, SiluAndMul, descriptor);

auto planned = new PlannedMeta{
descriptor,
graph::GraphTensor(workspace),
graph::GraphTensor(output),
graph::GraphTensor(input)};

return planned;
}

// Execution phase.
void run(void *planned_meta) {
auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

// Call the infiniop interface implemented in the earlier steps of this PR.
// Note: numel() is passed as the workspace size, which assumes a byte-typed workspace tensor.
INFINICORE_CHECK_ERROR(infiniopSiluAndMul(
planned->descriptor->desc,
planned->workspace->data(),
planned->workspace->numel(),
planned->output->data(),
planned->input->data(),
context::getStream()));
}

// Cleanup logic.
void cleanup(void **planned_meta_ptr) {
delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
*planned_meta_ptr = nullptr;
}

// Register the operator for all supported devices.
INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SiluAndMul, &plan, &run, &cleanup);

} // namespace infinicore::op::silu_and_mul_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
@@ -19,6 +19,7 @@
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/swiglu.hpp"
#include "ops/silu_and_mul.hpp"

namespace py = pybind11;

@@ -42,6 +43,7 @@ inline void bind(py::module &m) {
bind_swiglu(m);
bind_rope(m);
bind_embedding(m);
bind_silu_and_mul(m);
}

} // namespace infinicore::ops
31 changes: 31 additions & 0 deletions src/infinicore/pybind11/ops/silu_and_mul.hpp
@@ -0,0 +1,31 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/silu_and_mul.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_silu_and_mul(py::module &m) {
// Bind the out-of-place function: Tensor silu_and_mul(Tensor input)
m.def("silu_and_mul",
&op::silu_and_mul,
py::arg("input"),
R"doc(
SiLU and Mul (SwiGLU) activation function.
Input should be [..., 2*d], output will be [..., d].
)doc");

// Bind the in-place / destination-passing function: void silu_and_mul_(Tensor output, Tensor input)
m.def("silu_and_mul_",
&op::silu_and_mul_,
py::arg("output"),
py::arg("input"),
R"doc(
In-place or destination-specified SiLU and Mul (SwiGLU) activation function.
)doc");
}

} // namespace infinicore::ops
@@ -28,7 +28,7 @@ __device__ void causalSoftmaxKernel(
// 1 | * * * ... * * |
// 2 | * * * ... * * * |
// height: 3 col_id->
if (width + blockIdx.x >= threadIdx.x + height) {
if (width + blockIdx.x >= col + height) {
if constexpr (std::is_same_v<Tdata, half> || std::is_same_v<Tdata, cuda_bfloat16>) {
/*
* MUSA does not support CUDA's native `hexp` function.