diff --git a/include/infinicore/graph/graph.hpp b/include/infinicore/graph/graph.hpp
index c63b3272d..665ab6a17 100644
--- a/include/infinicore/graph/graph.hpp
+++ b/include/infinicore/graph/graph.hpp
@@ -15,10 +15,15 @@ class GraphTensor : public Tensor {
 };
 
 class GraphOperator {
+public:
+    virtual void run() const = 0;
+    virtual ~GraphOperator() = default;
+};
+
+class DispatchableGraphOperator : public GraphOperator {
 public:
-    void run() const;
-    ~GraphOperator();
+    void run() const override;
+    ~DispatchableGraphOperator() override;
 
 protected:
     using run_schema = void (*)(void *);
@@ -45,7 +50,7 @@ class Graph {
 } // namespace infinicore::graph
 
 #define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...)               \
-    class __OP_NAME__ : public graph::GraphOperator {             \
+    class __OP_NAME__ : public graph::DispatchableGraphOperator { \
     public:                                                       \
         using schema = void (*)(__VA_ARGS__);                     \
         using plan_schema = void *(*)(__VA_ARGS__);               \
@@ -75,12 +80,12 @@ class Graph {
         runner_ = run_dispatcher().lookup(__DEVICE_TYPE__);       \
         deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);
 
-#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)  \
-    auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__);    \
-    if (context::isGraphRecording()) {                       \
-        context::addGraphOperator(op);                       \
-    } else {                                                 \
-        op->run();                                           \
+#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...)   \
+    auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__);  \
+    if (context::isGraphRecording()) {                        \
+        context::addGraphOperator(___op);                     \
+    } else {                                                  \
+        ___op->run();                                         \
     }
 
 #define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
diff --git a/include/infinicore/ops/add.hpp b/include/infinicore/ops/add.hpp
index 1dd5df0ff..528cca18a 100644
--- a/include/infinicore/ops/add.hpp
+++ b/include/infinicore/ops/add.hpp
@@ -1,17 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class Add {
-public:
-    using schema = void (*)(Tensor, Tensor, Tensor);
-    static void execute(Tensor c, Tensor a, Tensor b);
-    static common::OpDispatcher &dispatcher();
-};
-Tensor add(Tensor a, Tensor b);
-void add_(Tensor c, Tensor a, Tensor b);
-Tensor operator+(Tensor a, Tensor b);
+INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);
+
+Tensor add(const Tensor &a, const Tensor &b);
+void add_(Tensor c, const Tensor &a, const Tensor &b);
+
 } // namespace infinicore::op
diff --git a/include/infinicore/ops/causal_softmax.hpp b/include/infinicore/ops/causal_softmax.hpp
index ae40d521c..2646852af 100644
--- a/include/infinicore/ops/causal_softmax.hpp
+++ b/include/infinicore/ops/causal_softmax.hpp
@@ -1,16 +1,14 @@
 #pragma once
 
 #include "../device.hpp"
+#include "../graph/graph.hpp"
 #include "common/op.hpp"
 
 namespace infinicore::op {
-class CausalSoftmax {
-public:
-    using schema = void (*)(Tensor, Tensor);
-    static void execute(Tensor output, Tensor input);
-    static common::OpDispatcher &dispatcher();
-};
-Tensor causal_softmax(Tensor input);
-void causal_softmax_(Tensor output, Tensor input);
+INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);
+
+Tensor causal_softmax(const Tensor &input);
+void causal_softmax_(Tensor output, const Tensor &input);
+
 } // namespace infinicore::op
diff --git a/include/infinicore/ops/distributed/allreduce.hpp b/include/infinicore/ops/distributed/allreduce.hpp
new file mode 100644
index 000000000..39f74243a
--- /dev/null
+++ b/include/infinicore/ops/distributed/allreduce.hpp
@@ -0,0 +1,24 @@
+#pragma once
+
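The headers above only show the macro invocations, so it can be hard to see what each op now exposes. Below is a hedged sketch, pieced together from the macro body in graph.hpp and the Add hunks later in this diff, of roughly what INFINICORE_GRAPH_OP_CLASS(Add, ...) generates; it is a simplified illustration, not the literal expansion.

    // Approximate shape of the generated class (simplified sketch, not part of the diff).
    class Add : public graph::DispatchableGraphOperator {
    public:
        using schema = void (*)(Tensor, const Tensor &, const Tensor &);       // eager call signature
        using plan_schema = void *(*)(Tensor, const Tensor &, const Tensor &); // backend plan() signature
        Add(Tensor c, const Tensor &a, const Tensor &b);                  // plans the op for c's device
        static void execute(Tensor c, const Tensor &a, const Tensor &b);  // record-or-run entry point
        // plus per-device plan/run/cleanup dispatchers, populated by
        // INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup)
    };
    // Eager path: add_(c, a, b) -> Add::execute -> constructs Add (which plans) and run()s it.
    // While context::isGraphRecording(), the same call only plans the op and hands it to
    // context::addGraphOperator(); DispatchableGraphOperator::run() replays it later, and the
    // registered cleanup() runs when the operator is destroyed.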
+#include "../../device.hpp" +#include "../../graph/graph.hpp" +#include "../common/op.hpp" + +#include + +namespace infinicore::op::distributed { +class AllReduce : public graph::GraphOperator { +public: + AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + ~AllReduce(); + void run() const override; + static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + +private: + void *planned_meta_; +}; + +Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); +void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator); + +} // namespace infinicore::op::distributed diff --git a/include/infinicore/ops/gemm.hpp b/include/infinicore/ops/gemm.hpp index 481d47cf6..4f76cee26 100644 --- a/include/infinicore/ops/gemm.hpp +++ b/include/infinicore/ops/gemm.hpp @@ -6,9 +6,9 @@ namespace infinicore::op { -INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float); +INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float); -Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f); -void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta); +Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f); +void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta); } // namespace infinicore::op diff --git a/include/infinicore/ops/mul.hpp b/include/infinicore/ops/mul.hpp index 83416bbd9..2eb480ddb 100644 --- a/include/infinicore/ops/mul.hpp +++ b/include/infinicore/ops/mul.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Mul { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor c, Tensor a, Tensor b); - static common::OpDispatcher &dispatcher(); -}; -Tensor mul(Tensor a, Tensor b); -void mul_(Tensor c, Tensor a, Tensor b); +INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &); + +Tensor mul(const Tensor &a, const Tensor &b); +void mul_(Tensor c, const Tensor &a, const Tensor &b); + } // namespace infinicore::op diff --git a/include/infinicore/ops/paged_attention.hpp b/include/infinicore/ops/paged_attention.hpp index 54d61fa89..8c906c95e 100644 --- a/include/infinicore/ops/paged_attention.hpp +++ b/include/infinicore/ops/paged_attention.hpp @@ -1,18 +1,20 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" #include namespace infinicore::op { -class PagedAttention { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional, float); - static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional, float); + +Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale); + +void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale); -Tensor paged_attention(Tensor q, 
Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float scale); -void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional alibi_slopes, float scale); } // namespace infinicore::op diff --git a/include/infinicore/ops/paged_caching.hpp b/include/infinicore/ops/paged_caching.hpp index e357cda38..403b4b738 100644 --- a/include/infinicore/ops/paged_caching.hpp +++ b/include/infinicore/ops/paged_caching.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class PagedCaching { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); - static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping); - static common::OpDispatcher &dispatcher(); -}; +INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &); -void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping); +void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping); } // namespace infinicore::op diff --git a/include/infinicore/ops/rearrange.hpp b/include/infinicore/ops/rearrange.hpp index 3576365e0..5db983ef8 100644 --- a/include/infinicore/ops/rearrange.hpp +++ b/include/infinicore/ops/rearrange.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class Rearrange { -public: - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor y, Tensor x); - static common::OpDispatcher &dispatcher(); -}; -Tensor rearrange(Tensor x); -void rearrange_(Tensor y, Tensor x); +INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &); + +Tensor rearrange(const Tensor &x); +void rearrange_(Tensor y, const Tensor &x); + } // namespace infinicore::op diff --git a/include/infinicore/ops/rms_norm.hpp b/include/infinicore/ops/rms_norm.hpp index 1212c446e..c7b2b2d72 100644 --- a/include/infinicore/ops/rms_norm.hpp +++ b/include/infinicore/ops/rms_norm.hpp @@ -1,16 +1,14 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class RMSNorm { -public: - using schema = void (*)(Tensor, Tensor, Tensor, float); - static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f); - static common::OpDispatcher &dispatcher(); -}; -Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f); -void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f); +INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float); + +Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f); +void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f); + } // namespace infinicore::op diff --git a/include/infinicore/ops/rope.hpp b/include/infinicore/ops/rope.hpp index a5f7792b9..8fd630ce1 100644 --- a/include/infinicore/ops/rope.hpp +++ b/include/infinicore/ops/rope.hpp @@ -1,21 +1,28 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "../nn/rope.hpp" #include "../tensor.hpp" #include "common/op.hpp" namespace infinicore::op { -class RoPE { -public: - using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo); - static void 
execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo); - static common::OpDispatcher &dispatcher(); -}; -// Internal function -void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo); +INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo); + +// Internal +void rope_(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo); + +// Public API +Tensor rope(const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo); -// Public API that uses infinicore::nn::RoPE::Algo -Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo); } // namespace infinicore::op diff --git a/include/infinicore/ops/swiglu.hpp b/include/infinicore/ops/swiglu.hpp index 47a3e0f44..7aa77e632 100644 --- a/include/infinicore/ops/swiglu.hpp +++ b/include/infinicore/ops/swiglu.hpp @@ -1,16 +1,15 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" +#include "../tensor.hpp" #include "common/op.hpp" namespace infinicore::op { -class SwiGLU { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor c, Tensor a, Tensor b); - static common::OpDispatcher &dispatcher(); -}; -Tensor swiglu(Tensor a, Tensor b); -void swiglu_(Tensor c, Tensor a, Tensor b); +INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &); + +Tensor swiglu(const Tensor &a, const Tensor &b); +void swiglu_(Tensor c, const Tensor &a, const Tensor &b); + } // namespace infinicore::op diff --git a/src/infinicore/graph/graph.cc b/src/infinicore/graph/graph.cc index 86944af36..c502e4e47 100644 --- a/src/infinicore/graph/graph.cc +++ b/src/infinicore/graph/graph.cc @@ -15,11 +15,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) { * GraphOperator * ========================= */ -void GraphOperator::run() const { +void DispatchableGraphOperator::run() const { runner_(planned_meta_); } -GraphOperator::~GraphOperator() { +DispatchableGraphOperator::~DispatchableGraphOperator() { if (deleter_) { deleter_(&planned_meta_); } diff --git a/src/infinicore/nn/linear.cc b/src/infinicore/nn/linear.cc index bb4fc29b1..0be993699 100644 --- a/src/infinicore/nn/linear.cc +++ b/src/infinicore/nn/linear.cc @@ -1,6 +1,7 @@ #include "infinicore/nn/linear.hpp" #include "../utils.hpp" #include "infinicore/ops.hpp" +#include "infinicore/ops/distributed/allreduce.hpp" #include "infinicore/ops/linear.hpp" #include #include @@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur } else { bias_ = Parameter(); // Default constructed empty parameter } - - // SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, bias, static_cast(dtype_)); } Tensor ColumnParallelLinear::forward(Tensor &input) const { @@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo } else { bias_ = Parameter(); // Default constructed empty parameter } - - // SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}", - // in_features, out_features, 
bias, static_cast(dtype_)); } Tensor RowParallelLinear::forward(Tensor &input) const { auto output = BaseLinear::forward(input); if ((tp_size_ > 1) && (communicator_ != nullptr)) { - - Size count = output->numel(); - DataType type = output->dtype(); - - infinirtStream_t stream = infinicore::context::getStream(); - - INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast(static_cast(type)), - INFINICCL_SUM, communicator_, stream)); - INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream)); - - // RUN_INFINI(infinirtStreamSynchronize(stream)); + op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_); } return output; } diff --git a/src/infinicore/ops/add/add.cc b/src/infinicore/ops/add/add.cc index ef776d632..815a2de27 100644 --- a/src/infinicore/ops/add/add.cc +++ b/src/infinicore/ops/add/add.cc @@ -3,24 +3,24 @@ namespace infinicore::op { -common::OpDispatcher &Add::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Add); -void Add::execute(Tensor c, Tensor a, Tensor b) { +Add::Add(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - dispatcher().lookup(c->device().getType())(c, a, b); + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); } -Tensor add(Tensor a, Tensor b) { +void Add::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b); +} + +Tensor add(const Tensor &a, const Tensor &b) { auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); add_(c, a, b); return c; } -void add_(Tensor c, Tensor a, Tensor b) { +void add_(Tensor c, const Tensor &a, const Tensor &b) { Add::execute(c, a, b); } diff --git a/src/infinicore/ops/add/add_infiniop.cc b/src/infinicore/ops/add/add_infiniop.cc index 29c36770c..bb377d667 100644 --- a/src/infinicore/ops/add/add_infiniop.cc +++ b/src/infinicore/ops/add/add_infiniop.cc @@ -1,50 +1,52 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/add.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::add_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopAddDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyAddDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Add, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, c, a, b; +}; -void calculate(Tensor c, Tensor a, Tensor b) { +void *plan(Tensor c, const Tensor &a, const Tensor &b) { size_t seed = hash_combine(c, b, a); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Add, + seed, + c->desc(), a->desc(), b->desc()); - auto desc_opt = cache.get(seed); - infiniopAddDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Add, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} - size_t workspace_size = 0; - 
INFINICORE_CHECK_ERROR(infiniopGetAddWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopAdd( - desc, workspace->data(), workspace_size, - c->data(), a->data(), b->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->c->data(), + planned->a->data(), + planned->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Add::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Add, &plan, &run, &cleanup); } // namespace infinicore::op::add_impl::infiniop diff --git a/src/infinicore/ops/causal_softmax/causal_softmax.cc b/src/infinicore/ops/causal_softmax/causal_softmax.cc index 3194dff94..328ff390e 100644 --- a/src/infinicore/ops/causal_softmax/causal_softmax.cc +++ b/src/infinicore/ops/causal_softmax/causal_softmax.cc @@ -1,37 +1,27 @@ #include "infinicore/ops/causal_softmax.hpp" - #include "../../utils.hpp" -#include - namespace infinicore::op { -common::OpDispatcher &CausalSoftmax::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(CausalSoftmax); -void CausalSoftmax::execute(Tensor output, Tensor input) { +CausalSoftmax::CausalSoftmax(Tensor output, const Tensor &input) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); - infinicore::context::setDevice(output->device()); - auto device_type = output->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No CausalSoftmax implementation found for device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(output->device().getType(), output, input); +} - func(output, input); +void CausalSoftmax::execute(Tensor output, const Tensor &input) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(CausalSoftmax, output, input); } -Tensor causal_softmax(Tensor input) { - Shape shape = input->shape(); - auto output = Tensor::empty(shape, input->dtype(), input->device()); +Tensor causal_softmax(const Tensor &input) { + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); causal_softmax_(output, input); return output; } -void causal_softmax_(Tensor output, Tensor input) { +void causal_softmax_(Tensor output, const Tensor &input) { CausalSoftmax::execute(output, input); } + } // namespace infinicore::op diff --git a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc index 082d0e642..e0a0595fb 100644 --- a/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc +++ b/src/infinicore/ops/causal_softmax/causal_softmax_infiniop.cc @@ -1,50 +1,49 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" #include "infinicore/ops/causal_softmax.hpp" -#include "infinicore/ops/common/cache.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::causal_softmax_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopCausalSoftmaxDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyCausalSoftmaxDescriptor(desc)); - desc = nullptr; - } - }); 
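Every *_infiniop.cc backend from here on is rewritten to the same three-function contract, so a condensed sketch of that lifecycle may help before reading each file; names such as Foo, out, and in are placeholders for this illustration, not code from the diff.

    // Sketch of the backend contract consumed by DispatchableGraphOperator (assumptions noted inline).
    struct PlannedMeta {
        std::shared_ptr<Descriptor> descriptor; // cached infiniop descriptor wrapper
        graph::GraphTensor workspace, out, in;  // tensors kept alive for graph replay
    };

    void *plan(Tensor out, const Tensor &in) {
        // Fetch or create the descriptor (keyed by tensor hashes), size a workspace
        // tensor, and capture everything run() will need as GraphTensors.
        return new PlannedMeta{/* descriptor, workspace, out, in */};
    }

    void run(void *planned_meta) {
        auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
        // Launch the infiniop kernel with p->descriptor, the workspace, the captured
        // tensor data pointers, and context::getStream().
    }

    void cleanup(void **planned_meta_ptr) {
        delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
        *planned_meta_ptr = nullptr;
    }

    // INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Foo, &plan, &run, &cleanup);
    // plan() runs once when the op is constructed (eagerly or while recording),
    // run() executes eagerly or on each graph replay, cleanup() runs on destruction.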
+INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, CausalSoftmax, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, output, input; +}; -void calculate(Tensor output, Tensor input) { +void *plan(Tensor output, const Tensor &input) { size_t seed = hash_combine(output, input); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, CausalSoftmax, + seed, output->desc(), input->desc()); - auto desc_opt = cache.get(seed); - infiniopCausalSoftmaxDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, CausalSoftmax, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( - context::getInfiniopHandle(device), &desc, - output->desc(), input->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(output), + graph::GraphTensor(input)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetCausalSoftmaxWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopCausalSoftmax( - desc, workspace->data(), workspace_size, - output->data(), input->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->output->data(), + planned->input->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - CausalSoftmax::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(CausalSoftmax, &plan, &run, &cleanup); } // namespace infinicore::op::causal_softmax_impl::infiniop diff --git a/src/infinicore/ops/distributed/allreduce.cc b/src/infinicore/ops/distributed/allreduce.cc new file mode 100644 index 000000000..ddfc238c9 --- /dev/null +++ b/src/infinicore/ops/distributed/allreduce.cc @@ -0,0 +1,50 @@ +#include "infinicore/ops/distributed/allreduce.hpp" +#include "../../utils.hpp" + +namespace infinicore::op::distributed { + +struct PlannedMeta { + graph::GraphTensor output, input; + infinicclReduceOp_t op; + infinicclComm_t communicator; +}; + +AllReduce::AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); + INFINICORE_ASSERT(output->is_contiguous() && input->is_contiguous()); + INFINICORE_ASSERT(output->numel() == input->numel()); + planned_meta_ = new PlannedMeta{graph::GraphTensor(output), graph::GraphTensor(input), op, communicator}; +} +AllReduce::~AllReduce() { + if (planned_meta_) { + PlannedMeta *meta = reinterpret_cast(planned_meta_); + delete meta; + } +} + +void AllReduce::run() const { + PlannedMeta *meta = reinterpret_cast(planned_meta_); + + INFINICORE_CHECK_ERROR(infinicclAllReduce(meta->input->data(), + meta->output->data(), + meta->input->numel(), + static_cast(static_cast(meta->input->dtype())), + meta->op, + meta->communicator, + infinicore::context::getStream())); +} + +void AllReduce::execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(AllReduce, output, input, op, communicator); +} + +Tensor 
allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); + allreduce_(output, input, op, communicator); + return output; +} + +void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator) { + AllReduce::execute(output, input, op, communicator); +} +} // namespace infinicore::op::distributed diff --git a/src/infinicore/ops/gemm/gemm.cc b/src/infinicore/ops/gemm/gemm.cc index e2b3924f7..765bc869f 100644 --- a/src/infinicore/ops/gemm/gemm.cc +++ b/src/infinicore/ops/gemm/gemm.cc @@ -5,16 +5,16 @@ namespace infinicore::op { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Gemm); -Gemm::Gemm(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +Gemm::Gemm(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b, alpha, beta); } -void Gemm::execute(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void Gemm::execute(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { INFINICORE_GRAPH_OP_RECORD_OR_RUN(Gemm, c, a, b, alpha, beta); } -Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { +Tensor gemm(const Tensor &a, const Tensor &b, float alpha, float beta) { Shape shape = a->shape(); Size size = a->ndim(); shape[size - 1] = b->size(size - 1); @@ -23,7 +23,7 @@ Tensor gemm(Tensor a, Tensor b, float alpha, float beta) { return c; } -void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { Gemm::execute(c, a, b, alpha, beta); } diff --git a/src/infinicore/ops/gemm/gemm_infiniop.cc b/src/infinicore/ops/gemm/gemm_infiniop.cc index 670fdbc2a..33a7271c0 100644 --- a/src/infinicore/ops/gemm/gemm_infiniop.cc +++ b/src/infinicore/ops/gemm/gemm_infiniop.cc @@ -11,7 +11,7 @@ struct PlannedMeta { float alpha, beta; }; -void *plan(Tensor c, Tensor a, Tensor b, float alpha, float beta) { +void *plan(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta) { size_t seed = hash_combine(c, a, b); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( diff --git a/src/infinicore/ops/infiniop_impl.hpp b/src/infinicore/ops/infiniop_impl.hpp index 2bf38c8c6..67c09554c 100644 --- a/src/infinicore/ops/infiniop_impl.hpp +++ b/src/infinicore/ops/infiniop_impl.hpp @@ -5,23 +5,46 @@ #include "infinicore/ops/common/cache.hpp" #include -#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \ - struct __DESC_TYPE__ { \ - infiniop##__OP_NAME__##Descriptor_t desc; \ - Descriptor(infiniop##__OP_NAME__##Descriptor_t desc) : desc(desc) {} \ - ~Descriptor() { \ - if (desc != nullptr) { \ - infiniopDestroy##__OP_NAME__##Descriptor(desc); \ - desc = nullptr; \ - } \ - } \ - }; \ - \ - thread_local common::OpCache> \ - caches( \ - __SIZE__, \ - [](std::shared_ptr<__DESC_TYPE__> &desc) { \ - desc = nullptr; \ +#define INFINIOP_CACHABLE_DESCRIPTOR(__DESC_TYPE__, __OP_NAME__, __SIZE__) \ + struct __DESC_TYPE__ { \ + infiniop##__OP_NAME__##Descriptor_t desc = nullptr; \ + \ + explicit __DESC_TYPE__(infiniop##__OP_NAME__##Descriptor_t d) \ + : desc(d) {} \ + \ + /* non-copyable */ \ + __DESC_TYPE__(const __DESC_TYPE__ &) = delete; \ + __DESC_TYPE__ &operator=(const __DESC_TYPE__ &) = delete; \ + \ + /* movable */ \ + __DESC_TYPE__(__DESC_TYPE__ &&other) noexcept \ + : desc(other.desc) { \ + other.desc = nullptr; \ + } \ + \ + 
__DESC_TYPE__ &operator=(__DESC_TYPE__ &&other) noexcept { \ + if (this != &other) { \ + if (desc != nullptr) { \ + infiniopDestroy##__OP_NAME__##Descriptor(desc); \ + } \ + desc = other.desc; \ + other.desc = nullptr; \ + } \ + return *this; \ + } \ + \ + ~__DESC_TYPE__() { \ + if (desc != nullptr) { \ + infiniopDestroy##__OP_NAME__##Descriptor(desc); \ + } \ + } \ + }; \ + \ + thread_local common::OpCache> \ + caches( \ + __SIZE__, \ + [](std::shared_ptr<__DESC_TYPE__> &desc) { \ + desc = nullptr; \ }); #define INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(__DESC_TYPE__, __DESC_NAME__, __INFINIOP_NAME__, __HASH_KEY__, ...) \ diff --git a/src/infinicore/ops/mul/mul.cc b/src/infinicore/ops/mul/mul.cc index 736e44269..6923fed9c 100644 --- a/src/infinicore/ops/mul/mul.cc +++ b/src/infinicore/ops/mul/mul.cc @@ -1,27 +1,26 @@ #include "infinicore/ops/mul.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &Mul::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Mul); -void Mul::execute(Tensor c, Tensor a, Tensor b) { +Mul::Mul(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - dispatcher().lookup(c->device().getType())(c, a, b); + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); +} + +void Mul::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(Mul, c, a, b); } -Tensor mul(Tensor a, Tensor b) { +Tensor mul(const Tensor &a, const Tensor &b) { auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); mul_(c, a, b); return c; } -void mul_(Tensor c, Tensor a, Tensor b) { +void mul_(Tensor c, const Tensor &a, const Tensor &b) { Mul::execute(c, a, b); } diff --git a/src/infinicore/ops/mul/mul_infiniop.cc b/src/infinicore/ops/mul/mul_infiniop.cc index 885a5f842..39a7bd87d 100644 --- a/src/infinicore/ops/mul/mul_infiniop.cc +++ b/src/infinicore/ops/mul/mul_infiniop.cc @@ -1,50 +1,51 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/mul.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::mul_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopMulDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyMulDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Mul, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, c, a, b; +}; -void calculate(Tensor c, Tensor a, Tensor b) { +void *plan(Tensor c, const Tensor &a, const Tensor &b) { size_t seed = hash_combine(c, b, a); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Mul, + seed, c->desc(), a->desc(), b->desc()); - auto desc_opt = cache.get(seed); - infiniopMulDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, Mul, descriptor); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} - size_t workspace_size = 0; - 
INFINICORE_CHECK_ERROR(infiniopGetMulWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopMul( - desc, workspace->data(), workspace_size, - c->data(), a->data(), b->data(), context::getStream())); + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->c->data(), + planned->a->data(), + planned->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - Mul::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Mul, &plan, &run, &cleanup); } // namespace infinicore::op::mul_impl::infiniop diff --git a/src/infinicore/ops/paged_attention/paged_attention.cc b/src/infinicore/ops/paged_attention/paged_attention.cc index 171614087..60de2ae66 100644 --- a/src/infinicore/ops/paged_attention/paged_attention.cc +++ b/src/infinicore/ops/paged_attention/paged_attention.cc @@ -1,27 +1,37 @@ #include "infinicore/ops/paged_attention.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &PagedAttention::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedAttention); -void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +PagedAttention::PagedAttention(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens); - infinicore::context::setDevice(out->device()); - dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); +} + +void PagedAttention::execute(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN( + PagedAttention, + out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); } -Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { auto out = Tensor::empty(q->shape(), q->dtype(), q->device()); paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); return out; } -void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { +void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &kv_lens, + std::optional alibi_slopes, float scale) { PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); } diff --git 
a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc index 3d367c5bb..733733a6b 100644 --- a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc +++ b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc @@ -1,54 +1,68 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/paged_attention.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::paged_attention_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopPagedAttentionDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional alibi_slopes, float scale) { - size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopPagedAttentionDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor( - context::getInfiniopHandle(device), &desc, - out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(), - alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr, - scale)); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopPagedAttention( - desc, workspace->data(), workspace_size, - out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(), - alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr, - context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k_cache, v_cache, block_tables, cache_lens; + std::optional alibi_slopes; + float scale; +}; + +void *plan(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache, + const Tensor &block_tables, const Tensor &cache_lens, + std::optional alibi_slopes, float scale) { + size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, PagedAttention, + seed, + out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), + block_tables->desc(), cache_lens->desc(), + alibi_slopes ? alibi_slopes.value()->desc() : nullptr, + scale); + + INFINIOP_WORKSPACE_TENSOR(workspace, PagedAttention, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(block_tables), + graph::GraphTensor(cache_lens), + alibi_slopes ? 
std::optional(graph::GraphTensor(*alibi_slopes)) : std::nullopt, + scale}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopPagedAttention( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->out->data(), + p->q->data(), + p->k_cache->data(), + p->v_cache->data(), + p->block_tables->data(), + p->cache_lens->data(), + p->alibi_slopes.has_value() ? p->alibi_slopes.value()->data() : nullptr, + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - PagedAttention::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedAttention, &plan, &run, &cleanup); } // namespace infinicore::op::paged_attention_impl::infiniop diff --git a/src/infinicore/ops/paged_caching/paged_caching.cc b/src/infinicore/ops/paged_caching/paged_caching.cc index cc14bf236..afc8bf0c6 100644 --- a/src/infinicore/ops/paged_caching/paged_caching.cc +++ b/src/infinicore/ops/paged_caching/paged_caching.cc @@ -1,21 +1,20 @@ #include "infinicore/ops/paged_caching.hpp" - #include "../../utils.hpp" namespace infinicore::op { -common::OpDispatcher &PagedCaching::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(PagedCaching); -void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { +PagedCaching::PagedCaching(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping); - infinicore::context::setDevice(k_cache->device()); - dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping); + INFINICORE_GRAPH_OP_DISPATCH(k->device().getType(), k_cache, v_cache, k, v, slot_mapping); +} + +void PagedCaching::execute(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(PagedCaching, k_cache, v_cache, k, v, slot_mapping); } -void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { +void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping); } diff --git a/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc b/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc index 7dcaf47a0..5e8be049a 100644 --- a/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc +++ b/src/infinicore/ops/paged_caching/paged_caching_infiniop.cc @@ -1,50 +1,57 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/paged_caching.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::paged_caching_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopPagedCachingDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyPagedCachingDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) { - size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = 
cache.get(seed); - infiniopPagedCachingDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor( - context::getInfiniopHandle(device), &desc, - k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopPagedCaching( - desc, workspace->data(), workspace_size, - k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, PagedCaching, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + + graph::GraphTensor workspace, k_cache, v_cache, k, v, slot_mapping; +}; + +void *plan(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping) { + size_t key = hash_combine(k_cache, v_cache, k, v, slot_mapping); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, PagedCaching, + key, k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, PagedCaching, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(slot_mapping)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopPagedCaching( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->k_cache->data(), + p->v_cache->data(), + p->k->data(), + p->v->data(), + p->slot_mapping->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - PagedCaching::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(PagedCaching, &plan, &run, &cleanup); } // namespace infinicore::op::paged_caching_impl::infiniop diff --git a/src/infinicore/ops/rearrange/rearrange.cc b/src/infinicore/ops/rearrange/rearrange.cc index c70a9e930..191d0871f 100644 --- a/src/infinicore/ops/rearrange/rearrange.cc +++ b/src/infinicore/ops/rearrange/rearrange.cc @@ -3,24 +3,30 @@ namespace infinicore::op { -common::OpDispatcher &Rearrange::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Rearrange); -void Rearrange::execute(Tensor y, Tensor x) { +Rearrange::Rearrange(Tensor y, const Tensor &x) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, x); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x); } -Tensor rearrange(Tensor x) { +void Rearrange::execute(Tensor y, const Tensor &x) { + auto op = std::make_shared(y, x); + if (context::isGraphRecording()) { + context::addGraphOperator(op); + } else { + op->run(); + } +} + +Tensor rearrange(const Tensor &x) { auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); rearrange_(y, x); return y; } -void rearrange_(Tensor y, Tensor x) { +void rearrange_(Tensor y, const Tensor &x) { Rearrange::execute(y, x); } + } // namespace infinicore::op diff --git 
a/src/infinicore/ops/rearrange/rearrange_infiniop.cc b/src/infinicore/ops/rearrange/rearrange_infiniop.cc index 71b43f027..f30b09e79 100644 --- a/src/infinicore/ops/rearrange/rearrange_infiniop.cc +++ b/src/infinicore/ops/rearrange/rearrange_infiniop.cc @@ -1,47 +1,46 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rearrange.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::rearrange_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRearrangeDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRearrangeDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, Rearrange, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor y, x; +}; -void calculate(Tensor y, Tensor x) { +void *plan(Tensor y, const Tensor &x) { size_t seed = hash_combine(y, x); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, Rearrange, + seed, y->desc(), + x->desc()); - auto desc_opt = cache.get(seed); - infiniopRearrangeDescriptor_t desc = nullptr; + return new PlannedMeta{ + descriptor, + graph::GraphTensor(y), + graph::GraphTensor(x)}; +} - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(device), &desc, y->desc(), x->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR( infiniopRearrange( - desc, - y->data(), - x->data(), + planned->descriptor->desc, + planned->y->data(), + planned->x->data(), context::getStream())); } -static bool registered = []() { - Rearrange::dispatcher().registerAll(&calculate, false); - return true; -}(); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(Rearrange, &plan, &run, &cleanup); } // namespace infinicore::op::rearrange_impl::infiniop diff --git a/src/infinicore/ops/rms_norm/rms_norm.cc b/src/infinicore/ops/rms_norm/rms_norm.cc index 20e598056..0f8f2e57b 100644 --- a/src/infinicore/ops/rms_norm/rms_norm.cc +++ b/src/infinicore/ops/rms_norm/rms_norm.cc @@ -1,27 +1,25 @@ #include "infinicore/ops/rms_norm.hpp" - #include "../../utils.hpp" namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RMSNorm); -common::OpDispatcher &RMSNorm::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void RMSNorm::execute(Tensor y, Tensor x, Tensor weight, float epsilon) { +RMSNorm::RMSNorm(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x, weight); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, x, weight, epsilon); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x, weight, epsilon); +} + +void RMSNorm::execute(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RMSNorm, y, x, weight, epsilon); } -Tensor rms_norm(Tensor x, Tensor weight, float epsilon) { +Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon) { auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); rms_norm_(y, x, weight, epsilon); return y; } -void rms_norm_(Tensor y, Tensor x, Tensor 
weight, float epsilon) { +void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { RMSNorm::execute(y, x, weight, epsilon); } diff --git a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc index 17b0ad888..9e4622a28 100644 --- a/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc +++ b/src/infinicore/ops/rms_norm/rms_norm_infiniop.cc @@ -1,50 +1,55 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rms_norm.hpp" -#include -namespace infinicore::op::rms_norm_impl::infiniop { +#include "../infiniop_impl.hpp" -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRMSNormDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRMSNormDescriptor(desc)); - desc = nullptr; - } - }); +namespace infinicore::op::rms_norm_impl::infiniop { -void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { - size_t seed = hash_combine(y, x, weight, epsilon); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RMSNorm, 100); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, y, x, weight; +}; - auto desc_opt = cache.get(seed); - infiniopRMSNormDescriptor_t desc = nullptr; +void *plan(Tensor y, const Tensor &x, const Tensor &weight, float epsilon) { + size_t seed = hash_combine(y, x, weight, epsilon); - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( - context::getInfiniopHandle(device), &desc, - y->desc(), x->desc(), weight->desc(), epsilon)); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RMSNorm, + seed, y->desc(), + x->desc(), + weight->desc(), + epsilon); + + INFINIOP_WORKSPACE_TENSOR(workspace, RMSNorm, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(x), + graph::GraphTensor(weight)}; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetRMSNormWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopRMSNorm( + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->y->data(), + planned->x->data(), + planned->weight->data(), + context::getStream())); +} - INFINICORE_CHECK_ERROR(infiniopRMSNorm( - desc, workspace->data(), workspace_size, - y->data(), x->data(), weight->data(), context::getStream())); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - RMSNorm::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RMSNorm, &plan, &run, &cleanup); } // namespace infinicore::op::rms_norm_impl::infiniop diff --git a/src/infinicore/ops/rope/rope.cc b/src/infinicore/ops/rope/rope.cc index e0a187db3..d28951b7d 100644 --- a/src/infinicore/ops/rope/rope.cc +++ b/src/infinicore/ops/rope/rope.cc @@ -1,37 +1,44 @@ #include "infinicore/ops/rope.hpp" - #include "../../utils.hpp" -#include "infinicore/context/context.hpp" - -#include namespace infinicore::op { -common::OpDispatcher &RoPE::dispatcher() { - static common::OpDispatcher 
dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(RoPE); -void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { +RoPE::RoPE(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(x_out, x, pos, sin_table, cos_table); - infinicore::context::setDevice(x_out->device()); - auto device_type = x_out->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(x_out->device().getType(), x_out, x, pos, sin_table, cos_table, algo); +} - func(x_out, x, pos, sin_table, cos_table, algo); +void RoPE::execute(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(RoPE, x_out, x, pos, sin_table, cos_table, algo); } -void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { +void rope_(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { RoPE::execute(x_out, x, pos, sin_table, cos_table, algo); } -Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) { - Shape shape = x->shape(); - auto x_out = Tensor::empty(shape, x->dtype(), x->device()); +Tensor rope(const Tensor &x, + const Tensor &pos, + const Tensor &sin_table, + const Tensor &cos_table, + infinicore::nn::RoPE::Algo algo) { + auto x_out = Tensor::empty(x->shape(), x->dtype(), x->device()); rope_(x_out, x, pos, sin_table, cos_table, algo); return x_out; } diff --git a/src/infinicore/ops/rope/rope_infiniop.cc b/src/infinicore/ops/rope/rope_infiniop.cc index 412daa925..850c2d0a2 100644 --- a/src/infinicore/ops/rope/rope_infiniop.cc +++ b/src/infinicore/ops/rope/rope_infiniop.cc @@ -1,69 +1,81 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/rope.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::rope_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopRoPEDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyRoPEDescriptor(desc)); - desc = nullptr; - } - }); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, RoPE, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace; + graph::GraphTensor x_out; + graph::GraphTensor x; + graph::GraphTensor pos; + graph::GraphTensor sin; + graph::GraphTensor cos; +}; -void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) { - // Convert infinicore::nn::RoPE::Algo to infiniopRoPEAlgo_t - infiniopRoPEAlgo_t infiniop_algo; +static infiniopRoPEAlgo_t to_infiniop_algo(infinicore::nn::RoPE::Algo algo) { switch (algo) { case infinicore::nn::RoPE::Algo::GPT_J: - infiniop_algo = INFINIOP_ROPE_ALGO_GPT_J; - break; + return INFINIOP_ROPE_ALGO_GPT_J; case infinicore::nn::RoPE::Algo::GPT_NEOX: - infiniop_algo = 
INFINIOP_ROPE_ALGO_GPT_NEOX; - break; + return INFINIOP_ROPE_ALGO_GPT_NEOX; default: - throw std::runtime_error("Unsupported RoPE algorithm: " + std::to_string(static_cast(algo))); + throw std::runtime_error("Unsupported RoPE algorithm"); } +} - // Create hash key for descriptor caching - size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache); - hash_combine(key, std::hash()(static_cast(infiniop_algo))); +void *plan(Tensor x_out, + const Tensor &x, + const Tensor &pos, + const Tensor &sin, + const Tensor &cos, + infinicore::nn::RoPE::Algo algo) { + auto infiniop_algo = to_infiniop_algo(algo); + size_t key = hash_combine(x_out, x, pos, sin, cos, static_cast(infiniop_algo)); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, RoPE, key, x_out->desc(), + x->desc(), + pos->desc(), + sin->desc(), + cos->desc(), + infiniop_algo); - auto desc_opt = cache.get(key); - infiniopRoPEDescriptor_t desc = nullptr; + INFINIOP_WORKSPACE_TENSOR(workspace, RoPE, descriptor); + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(x_out), + graph::GraphTensor(x), + graph::GraphTensor(pos), + graph::GraphTensor(sin), + graph::GraphTensor(cos)}; +} - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( - context::getInfiniopHandle(device), &desc, - x_out->desc(), x->desc(), - pos->desc(), sin_cache->desc(), cos_cache->desc(), - infiniop_algo)); - cache.put(key, desc); - } else { - desc = *desc_opt; - } +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetRoPEWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); + INFINICORE_CHECK_ERROR( + infiniopRoPE( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->x_out->data(), + p->x->data(), + p->pos->data(), + p->sin->data(), + p->cos->data(), + context::getStream())); +} - // InfiniOP reads from x and writes to x_out (handles copying internally) - INFINICORE_CHECK_ERROR(infiniopRoPE( - desc, workspace->data(), workspace_size, - x_out->data(), x->data(), pos->data(), - sin_cache->data(), cos_cache->data(), context::getStream())); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - RoPE::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(RoPE, &plan, &run, &cleanup); } // namespace infinicore::op::rope_impl::infiniop diff --git a/src/infinicore/ops/swiglu/swiglu.cc b/src/infinicore/ops/swiglu/swiglu.cc index 5646180e7..8ee0682ad 100644 --- a/src/infinicore/ops/swiglu/swiglu.cc +++ b/src/infinicore/ops/swiglu/swiglu.cc @@ -1,37 +1,26 @@ #include "infinicore/ops/swiglu.hpp" - #include "../../utils.hpp" -#include - namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SwiGLU); -common::OpDispatcher &SwiGLU::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void SwiGLU::execute(Tensor c, Tensor a, Tensor b) { +SwiGLU::SwiGLU(Tensor c, const Tensor &a, const Tensor &b) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); - infinicore::context::setDevice(c->device()); - auto device_type = c->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No SwiGLU implementation found for 
device type: " + std::to_string(static_cast(device_type))); - } + INFINICORE_GRAPH_OP_DISPATCH(c->device().getType(), c, a, b); +} - func(c, a, b); +void SwiGLU::execute(Tensor c, const Tensor &a, const Tensor &b) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(SwiGLU, c, a, b); } -Tensor swiglu(Tensor a, Tensor b) { - Shape shape = a->shape(); - auto c = Tensor::empty(shape, a->dtype(), a->device()); +Tensor swiglu(const Tensor &a, const Tensor &b) { + auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); swiglu_(c, a, b); return c; } -void swiglu_(Tensor c, Tensor a, Tensor b) { +void swiglu_(Tensor c, const Tensor &a, const Tensor &b) { SwiGLU::execute(c, a, b); } + } // namespace infinicore::op diff --git a/src/infinicore/ops/swiglu/swiglu_infiniop.cc b/src/infinicore/ops/swiglu/swiglu_infiniop.cc index 4a963993b..fbb76b570 100644 --- a/src/infinicore/ops/swiglu/swiglu_infiniop.cc +++ b/src/infinicore/ops/swiglu/swiglu_infiniop.cc @@ -1,50 +1,55 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/swiglu.hpp" -#include + +#include "../infiniop_impl.hpp" namespace infinicore::op::swiglu_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopSwiGLUDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroySwiGLUDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor c, Tensor a, Tensor b) { - size_t seed = hash_combine(c, b, a); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopSwiGLUDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( - context::getInfiniopHandle(device), &desc, - c->desc(), a->desc(), b->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetSwiGLUWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopSwiGLU( - desc, workspace->data(), workspace_size, - c->data(), a->data(), b->data(), context::getStream())); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SwiGLU, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace; + graph::GraphTensor c; + graph::GraphTensor a; + graph::GraphTensor b; +}; + +void *plan(Tensor c, const Tensor &a, const Tensor &b) { + size_t key = hash_combine(c, a, b); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, SwiGLU, + key, c->desc(), a->desc(), b->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, SwiGLU, descriptor); + + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(c), + graph::GraphTensor(a), + graph::GraphTensor(b)}; +} + +void run(void *planned_meta) { + auto *p = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR( + infiniopSwiGLU( + p->descriptor->desc, + p->workspace->data(), + p->workspace->numel(), + p->c->data(), + p->a->data(), + p->b->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; } -static bool registered = []() { - SwiGLU::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SwiGLU, &plan, &run, &cleanup); } // namespace infinicore::op::swiglu_impl::infiniop diff --git a/src/infinicore/utils.hpp 
b/src/infinicore/utils.hpp index cf8e69789..fd0578a5e 100644 --- a/src/infinicore/utils.hpp +++ b/src/infinicore/utils.hpp @@ -49,6 +49,7 @@ inline struct SpdlogInitializer { + ":" + std::to_string(__LINE__) + "."); \ } \ } \ + infinicore::context::setDevice((FIRST___)->device()); \ } while (0) #define INFINICORE_ASSERT(CONDITION__) \ diff --git a/test/infinicore/graph/attention.py b/test/infinicore/graph/attention.py new file mode 100644 index 000000000..cae70dc04 --- /dev/null +++ b/test/infinicore/graph/attention.py @@ -0,0 +1,356 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework import BaseOperatorTest, GenericTestRunner, TensorSpec, TestCase +from framework.tensor import TensorInitializer + +import infinicore + +# Test cases format: (nlayers, batch_size, hidden_size, nhead, nkvhead, dim, seqlen, past_seqlen, max_seqlen) +_TEST_CASES_DATA = [ + (28, 1, 3584, 28, 28, 128, 1, 256, 512), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-4, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-4, "rtol": 1e-3}, + infinicore.bfloat16: {"atol": 1e-4, "rtol": 5e-2}, +} +_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16] + + +def parse_test_cases(): + cases = [] + for ( + nlayers, + batch_size, + hidden_size, + nhead, + nkvhead, + dim, + seqlen, + past_seqlen, + max_seqlen, + ) in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + hidden_states = TensorSpec.from_tensor( + (batch_size, seqlen, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + pos_ids = TensorSpec.from_tensor( + (batch_size, seqlen), + dtype=infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=0, + high=max_seqlen, + ) + k_cache = TensorSpec.from_tensor( + (nlayers, batch_size, nkvhead, max_seqlen, dim), + dtype=dtype, + scale=1e-1, + bias=-5e-2, + ) + v_cache = TensorSpec.from_tensor( + (nlayers, batch_size, nkvhead, max_seqlen, dim), + dtype=dtype, + scale=1e-1, + bias=-5e-2, + ) + q_proj_w = TensorSpec.from_tensor( + (nhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + k_proj_w = TensorSpec.from_tensor( + (nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + v_proj_w = TensorSpec.from_tensor( + (nkvhead * dim, hidden_size), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + o_proj_w = TensorSpec.from_tensor( + (hidden_size, nhead * dim), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + norm_w = TensorSpec.from_tensor( + (hidden_size,), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + sin_table = TensorSpec.from_tensor( + (max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + cos_table = TensorSpec.from_tensor( + (max_seqlen, dim // 2), dtype=dtype, scale=1e-1, bias=-5e-2 + ) + + # Out-of-place + cases.append( + TestCase( + inputs=[ + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + ], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="Graph", + ) + ) + + return cases + + +def torch_rope( + q: torch.Tensor, + k: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + pos_ids: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + q, k: [B, H, S, D] + sin, cos: [max_S, D//2] + pos_ids: [B, S] + """ + + def rotate_half(x: torch.Tensor) -> torch.Tensor: + # x: [..., head_dim] + x_even = x[..., 0::2] + x_odd = x[..., 1::2] + return torch.stack((-x_odd, x_even), 
dim=-1).flatten(-2) + + B, H, S, D = q.shape + assert D % 2 == 0 + + # Gather sin/cos by position + # -> [B, S, D//2] + sin = sin[pos_ids] + cos = cos[pos_ids] + + # Expand to broadcast over heads + # -> [B, 1, S, D//2] + sin = sin.unsqueeze(1) + cos = cos.unsqueeze(1) + + # Interleave to full dim + sin = torch.repeat_interleave(sin, 2, dim=-1) + cos = torch.repeat_interleave(cos, 2, dim=-1) + + # Apply RoPE + q_rot = (q * cos) + (rotate_half(q) * sin) + k_rot = (k * cos) + (rotate_half(k) * sin) + + return q_rot, k_rot + + +class OpTest(BaseOperatorTest): + """Test Operator Graph""" + + def __init__(self): + super().__init__("Graph") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator( + self, + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + **kwargs, + ): + B, S, D = hidden_states.shape + + for layer in range(nlayers): + # ---- RMSNorm ---- + var = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(var + 1e-5) * norm_w + + # ---- QKV projection ---- + q = hidden_states @ q_proj_w.T + k = hidden_states @ k_proj_w.T + v = hidden_states @ v_proj_w.T + + q = q.view(B, S, nhead, dim).transpose(1, 2) # [B,H,S,Dh] + k = k.view(B, S, nkvhead, dim).transpose(1, 2) + v = v.view(B, S, nkvhead, dim).transpose(1, 2) + + # ---- RoPE ---- + q, k = torch_rope( + q, + k, + sin_table, + cos_table, + pos_ids, + ) + + # ---- KV cache update ---- + k_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = k + v_cache[layer, :, :, past_seqlen : past_seqlen + S, :] = v + + K = k_cache[layer, :, :, 0 : past_seqlen + S, :] + V = v_cache[layer, :, :, 0 : past_seqlen + S, :] + + # ---- Scaled Dot Product Attention (fused) ---- + def scaled_dot_product_attention( + query, key, value, is_causal=False, enable_gqa=False + ) -> torch.Tensor: + S, L = query.size(-2), key.size(-2) + scale_factor = query.size(-1) ** -0.5 + attn_bias = torch.zeros(S, L, dtype=query.dtype, device=query.device) + if is_causal: + mask = torch.tril(attn_bias + 1, diagonal=-1).flip(dims=[-2, -1]) + attn_bias = torch.where(mask == 1, -torch.inf, 0.0) + + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave( + query.size(-3) // value.size(-3), -3 + ) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value + + attn_out = scaled_dot_product_attention( + q, + K, + V, + is_causal=True, + enable_gqa=True, + ) # [B,H,S,Dh] + + # ---- Output projection ---- + attn_out = attn_out.transpose(1, 2).contiguous() + attn_out = attn_out.view(B, S, nhead * dim) + + hidden_states = attn_out @ o_proj_w.T + + return hidden_states + + def infinicore_operator( + self, + hidden_states, + pos_ids, + nhead, + nkvhead, + dim, + past_seqlen, + nlayers, + k_cache, + v_cache, + q_proj_w, + k_proj_w, + v_proj_w, + o_proj_w, + norm_w, + sin_table, + cos_table, + **kwargs, + ): + """Record graph and run""" + input_hidden_states = hidden_states + B, S, D = input_hidden_states.shape + + infinicore.start_graph_recording() + for layer in range(nlayers): + hidden_states = infinicore.nn.functional.rms_norm( + hidden_states, norm_w.shape, norm_w, 1e-5 + ) + q = infinicore.nn.functional.linear(hidden_states, q_proj_w) + k = infinicore.nn.functional.linear(hidden_states, k_proj_w) + v = 
infinicore.nn.functional.linear(hidden_states, v_proj_w) + + q = q.view((B, S, nhead, dim)) + k = k.view((B, S, nkvhead, dim)) + v = v.view((B, S, nkvhead, dim)) + q = infinicore.nn.functional.rope( + q, + pos_ids, + sin_table, + cos_table, + infinicore.nn.functional.RopeAlgo.GPT_J, + ) + k = infinicore.nn.functional.rope( + k, + pos_ids, + sin_table, + cos_table, + infinicore.nn.functional.RopeAlgo.GPT_J, + ) + + # [B, KVH, total_len, D] + full_k = ( + k_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S) + ) + full_v = ( + v_cache.narrow(0, layer, 1).squeeze(0).narrow(2, 0, past_seqlen + S) + ) + full_k.narrow(2, past_seqlen, S).copy_(k.permute((0, 2, 1, 3))) + full_v.narrow(2, past_seqlen, S).copy_(v.permute((0, 2, 1, 3))) + + G = nhead // nkvhead + L = past_seqlen + S + + full_q = ( + q.permute((0, 2, 1, 3)).contiguous().view((B * nkvhead, G * S, dim)) + ) + full_k = full_k.view((B * nkvhead, L, dim)) + full_v = full_v.view((B * nkvhead, L, dim)) + + attn_score = infinicore.matmul( + full_q, full_k.permute((0, 2, 1)), alpha=dim**-0.5 + ) + # [B * H, S, total_len] + attn_score = attn_score.view((B * nhead, S, L)) + infinicore.nn.functional.causal_softmax(attn_score, out=attn_score) + attn_out = infinicore.matmul(attn_score, full_v) + attn_out = ( + attn_out.view((B, nhead, S, dim)) + .permute((0, 2, 1, 3)) + .contiguous() + .view((B, S, nhead * dim)) + ) + hidden_states = infinicore.nn.functional.linear(attn_out, o_proj_w) + + op_graph = infinicore.stop_graph_recording() + + op_graph.run() + return hidden_states + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/graph/graph.py b/test/infinicore/graph/graph.py deleted file mode 100644 index 2f8927110..000000000 --- a/test/infinicore/graph/graph.py +++ /dev/null @@ -1,85 +0,0 @@ -import sys -import os - -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) - -import torch -import infinicore -from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner - -# Test cases format: (in_shape, proj_w_shape) -_TEST_CASES_DATA = [ - ((32, 4096), (4096, 4096)), -] - -_TOLERANCE_MAP = { - infinicore.float16: {"atol": 0, "rtol": 1e-2}, - infinicore.float32: {"atol": 1e-4, "rtol": 1e-3}, - infinicore.bfloat16: {"atol": 0, "rtol": 5e-2}, -} -_TENSOR_DTYPES = [infinicore.float16, infinicore.float32, infinicore.bfloat16] - - -def parse_test_cases(): - cases = [] - for in_shape, proj_w_shape in _TEST_CASES_DATA: - for dtype in _TENSOR_DTYPES: - tol = _TOLERANCE_MAP[dtype] - in_spec = TensorSpec.from_tensor(in_shape, dtype=dtype) - proj_w_spec = TensorSpec.from_tensor(proj_w_shape, dtype=dtype) - temp_spec = TensorSpec.from_tensor(in_shape, dtype=dtype) - - # Out-of-place - cases.append( - TestCase( - inputs=[in_spec, proj_w_spec, temp_spec], - kwargs={}, - output_spec=None, - comparison_target=None, - tolerance=tol, - description="Graph", - ) - ) - - return cases - - -class OpTest(BaseOperatorTest): - """Test Operator Graph""" - - def __init__(self): - super().__init__("Graph") - - def get_test_cases(self): - return parse_test_cases() - - def torch_operator(self, *args, **kwargs): - a = args[0] - b = args[1] - - return torch.matmul(a, b) - - def infinicore_operator(self, *args, **kwargs): - """Record graph and run""" - a = args[0] - b = args[1] - temp_a = args[2] - - infinicore.start_graph_recording() - c = infinicore.matmul(temp_a, b) - op_graph = infinicore.stop_graph_recording() - - 
temp_a.copy_(a)
-        op_graph.run()
-
-        return c
-
-
-def main():
-    """Main entry point"""
-    runner = GenericTestRunner(OpTest)
-    runner.run_and_exit()
-
-
-if __name__ == "__main__":
-    main()
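
The record/replay contract that the removed test/infinicore/graph/graph.py relied on, and that the new test/infinicore/graph/attention.py extends to a full multi-layer attention block, can be summarized with the short sketch below. It is only a sketch: it uses just the Python entry points that appear in this patch (start_graph_recording, stop_graph_recording, matmul, and the run() method of the object returned by stop_graph_recording), the helper name record_and_replay is purely illustrative, and `a`, `b`, `temp_a` stand for already-allocated infinicore tensors on the same device, which the real tests obtain through the framework's TensorSpec machinery.

# Minimal sketch of the record/replay pattern exercised by the tests in this
# patch. Tensor construction is assumed to happen elsewhere: `a`, `b`, and
# `temp_a` are pre-allocated infinicore tensors on the same device.
import infinicore


def record_and_replay(a, b, temp_a):
    # Ops issued while recording are captured into a graph instead of being
    # dispatched immediately (the branch of INFINICORE_GRAPH_OP_RECORD_OR_RUN
    # taken when context::isGraphRecording() is true).
    infinicore.start_graph_recording()
    c = infinicore.matmul(temp_a, b)
    op_graph = infinicore.stop_graph_recording()

    # Inputs may be rewritten after recording; run() replays the captured
    # operators against the tensors' current contents, so `c` ends up holding
    # the product of `a` and `b`, which is what the removed graph.py test
    # compared against torch.matmul(a, b).
    temp_a.copy_(a)
    op_graph.run()
    return c

attention.py applies the same pattern at a larger scale: the entire 28-layer attention loop is recorded once, stop_graph_recording() returns the graph, and a single op_graph.run() replays it before the result is compared against the PyTorch reference.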