23 changes: 14 additions & 9 deletions include/infinicore/graph/graph.hpp
@@ -15,10 +15,15 @@ class GraphTensor : public Tensor {
};

class GraphOperator {
public:
virtual void run() const = 0;
virtual ~GraphOperator() = default;
};

class DispatchableGraphOperator : public GraphOperator {
public:
void run() const;
~GraphOperator();
void run() const override;
~DispatchableGraphOperator() override;

protected:
using run_schema = void (*)(void *);
@@ -45,7 +50,7 @@ class Graph {
} // namespace infinicore::graph

#define INFINICORE_GRAPH_OP_CLASS(__OP_NAME__, ...) \
class __OP_NAME__ : public graph::GraphOperator { \
class __OP_NAME__ : public graph::DispatchableGraphOperator { \
public: \
using schema = void (*)(__VA_ARGS__); \
using plan_schema = void *(*)(__VA_ARGS__); \
@@ -75,12 +80,12 @@ class Graph {
runner_ = run_dispatcher().lookup(__DEVICE_TYPE__); \
deleter_ = cleanup_dispatcher().lookup(__DEVICE_TYPE__);

#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
context::addGraphOperator(op); \
} else { \
op->run(); \
#define INFINICORE_GRAPH_OP_RECORD_OR_RUN(__OP_NAME__, ...) \
auto ___op = std::make_shared<__OP_NAME__>(__VA_ARGS__); \
if (context::isGraphRecording()) { \
context::addGraphOperator(___op); \
} else { \
___op->run(); \
}

#define INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(__OP_NAME__, __PLAN_F__, __RUN_F__, __CLEANUP_F__) \
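For readers following the refactor, here is a minimal sketch (assumed usage, not taken from this PR's source files) of how the two macros combine after this change. It assumes the INFINICORE_GRAPH_OP_CLASS-generated class exposes a constructor matching its schema arguments; the Add signature matches the updated include/infinicore/ops/add.hpp below.

#include "infinicore/ops/add.hpp"

namespace infinicore::op {

void add_(Tensor c, const Tensor &a, const Tensor &b) {
    // Expands to: auto ___op = std::make_shared<Add>(c, a, b); then either
    // context::addGraphOperator(___op) while a graph is being recorded,
    // or ___op->run() for eager execution.
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(Add, c, a, b);
}

} // namespace infinicore::op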
15 changes: 6 additions & 9 deletions include/infinicore/ops/add.hpp
@@ -1,17 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Add {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor add(Tensor a, Tensor b);
void add_(Tensor c, Tensor a, Tensor b);
Tensor operator+(Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(Add, Tensor, const Tensor &, const Tensor &);

Tensor add(const Tensor &a, const Tensor &b);
void add_(Tensor c, const Tensor &a, const Tensor &b);

} // namespace infinicore::op
14 changes: 6 additions & 8 deletions include/infinicore/ops/causal_softmax.hpp
@@ -1,16 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class CausalSoftmax {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor output, Tensor input);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor causal_softmax(Tensor input);
void causal_softmax_(Tensor output, Tensor input);
INFINICORE_GRAPH_OP_CLASS(CausalSoftmax, Tensor, const Tensor &);

Tensor causal_softmax(const Tensor &input);
void causal_softmax_(Tensor output, const Tensor &input);

} // namespace infinicore::op
24 changes: 24 additions & 0 deletions include/infinicore/ops/distributed/allreduce.hpp
@@ -0,0 +1,24 @@
#pragma once

#include "../../device.hpp"
#include "../../graph/graph.hpp"
#include "../common/op.hpp"

#include <infiniccl.h>

namespace infinicore::op::distributed {
class AllReduce : public graph::GraphOperator {
public:
AllReduce(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
~AllReduce();
void run() const override;
static void execute(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);

private:
void *planned_meta_;
};

Tensor allreduce(const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);
void allreduce_(Tensor output, const Tensor &input, infinicclReduceOp_t op, infinicclComm_t communicator);

} // namespace infinicore::op::distributed
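The new AllReduce op derives from graph::GraphOperator directly and overrides run() itself rather than going through the INFINICORE_GRAPH_OP_CLASS machinery. A minimal usage sketch of the in-place entry point follows, assuming Tensor lives in the top-level infinicore namespace and an already-initialized infinicclComm_t; aliasing input and output mirrors the RowParallelLinear::forward call in src/infinicore/nn/linear.cc below.

#include "infinicore/ops/distributed/allreduce.hpp"

// Hypothetical helper: sums `partial` across all ranks of `comm`, in place.
void reduce_across_ranks(infinicore::Tensor partial, infinicclComm_t comm) {
    // Output and input may alias; linear.cc below passes the same tensor twice.
    infinicore::op::distributed::allreduce_(partial, partial, INFINICCL_SUM, comm);
}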
6 changes: 3 additions & 3 deletions include/infinicore/ops/gemm.hpp
@@ -6,9 +6,9 @@

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, Tensor, Tensor, float, float);
INFINICORE_GRAPH_OP_CLASS(Gemm, Tensor, const Tensor &, const Tensor &, float, float);

Tensor gemm(Tensor a, Tensor b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, Tensor a, Tensor b, float alpha, float beta);
Tensor gemm(const Tensor &a, const Tensor &b, float alpha = 1.0f, float beta = 0.0f);
void gemm_(Tensor c, const Tensor &a, const Tensor &b, float alpha, float beta);

} // namespace infinicore::op
14 changes: 6 additions & 8 deletions include/infinicore/ops/mul.hpp
@@ -1,16 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Mul {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor mul(Tensor a, Tensor b);
void mul_(Tensor c, Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(Mul, Tensor, const Tensor &, const Tensor &);

Tensor mul(const Tensor &a, const Tensor &b);
void mul_(Tensor c, const Tensor &a, const Tensor &b);

} // namespace infinicore::op
18 changes: 10 additions & 8 deletions include/infinicore/ops/paged_attention.hpp
@@ -1,18 +1,20 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

class PagedAttention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(PagedAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, const Tensor &, std::optional<Tensor>, float);

Tensor paged_attention(const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);

void paged_attention_(Tensor out, const Tensor &q, const Tensor &k_cache, const Tensor &v_cache,
const Tensor &block_tables, const Tensor &kv_lens,
std::optional<Tensor> alibi_slopes, float scale);

Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
10 changes: 3 additions & 7 deletions include/infinicore/ops/paged_caching.hpp
@@ -1,17 +1,13 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

class PagedCaching {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
static common::OpDispatcher<schema> &dispatcher();
};
INFINICORE_GRAPH_OP_CLASS(PagedCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &);

void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
void paged_caching_(Tensor k_cache, Tensor v_cache, const Tensor &k, const Tensor &v, const Tensor &slot_mapping);

} // namespace infinicore::op
14 changes: 6 additions & 8 deletions include/infinicore/ops/rearrange.hpp
@@ -1,16 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class Rearrange {
public:
using schema = void (*)(Tensor, Tensor);
static void execute(Tensor y, Tensor x);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor rearrange(Tensor x);
void rearrange_(Tensor y, Tensor x);
INFINICORE_GRAPH_OP_CLASS(Rearrange, Tensor, const Tensor &);

Tensor rearrange(const Tensor &x);
void rearrange_(Tensor y, const Tensor &x);

} // namespace infinicore::op
14 changes: 6 additions & 8 deletions include/infinicore/ops/rms_norm.hpp
@@ -1,16 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class RMSNorm {
public:
using schema = void (*)(Tensor, Tensor, Tensor, float);
static void execute(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor rms_norm(Tensor x, Tensor weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, Tensor x, Tensor weight, float epsilon = 1e-5f);
INFINICORE_GRAPH_OP_CLASS(RMSNorm, Tensor, const Tensor &, const Tensor &, float);

Tensor rms_norm(const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);
void rms_norm_(Tensor y, const Tensor &x, const Tensor &weight, float epsilon = 1e-5f);

} // namespace infinicore::op
27 changes: 17 additions & 10 deletions include/infinicore/ops/rope.hpp
@@ -1,21 +1,28 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../nn/rope.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class RoPE {
public:
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
static common::OpDispatcher<schema> &dispatcher();
};

// Internal function
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
INFINICORE_GRAPH_OP_CLASS(RoPE, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);

// Internal
void rope_(Tensor x_out,
const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);

// Public API
Tensor rope(const Tensor &x,
const Tensor &pos,
const Tensor &sin_table,
const Tensor &cos_table,
infinicore::nn::RoPE::Algo algo);

// Public API that uses infinicore::nn::RoPE::Algo
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
} // namespace infinicore::op
15 changes: 7 additions & 8 deletions include/infinicore/ops/swiglu.hpp
@@ -1,16 +1,15 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "../tensor.hpp"
#include "common/op.hpp"

namespace infinicore::op {
class SwiGLU {
public:
using schema = void (*)(Tensor, Tensor, Tensor);
static void execute(Tensor c, Tensor a, Tensor b);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor swiglu(Tensor a, Tensor b);
void swiglu_(Tensor c, Tensor a, Tensor b);
INFINICORE_GRAPH_OP_CLASS(SwiGLU, Tensor, const Tensor &, const Tensor &);

Tensor swiglu(const Tensor &a, const Tensor &b);
void swiglu_(Tensor c, const Tensor &a, const Tensor &b);

} // namespace infinicore::op
4 changes: 2 additions & 2 deletions src/infinicore/graph/graph.cc
@@ -15,11 +15,11 @@ GraphTensor::GraphTensor(const Tensor &tensor) : Tensor(tensor->to_blob()) {
* GraphOperator
* ========================= */

void GraphOperator::run() const {
void DispatchableGraphOperator::run() const {
runner_(planned_meta_);
}

GraphOperator::~GraphOperator() {
DispatchableGraphOperator::~DispatchableGraphOperator() {
if (deleter_) {
deleter_(&planned_meta_);
}
19 changes: 2 additions & 17 deletions src/infinicore/nn/linear.cc
@@ -1,6 +1,7 @@
#include "infinicore/nn/linear.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops/distributed/allreduce.hpp"
#include "infinicore/ops/linear.hpp"
#include <optional>
#include <spdlog/spdlog.h>
@@ -102,9 +103,6 @@ ColumnParallelLinear::ColumnParallelLinear(size_t in_features, size_t out_featur
} else {
bias_ = Parameter(); // Default constructed empty parameter
}

// SPDLOG_DEBUG("Created ColumnParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
}

Tensor ColumnParallelLinear::forward(Tensor &input) const {
Expand Down Expand Up @@ -138,26 +136,13 @@ RowParallelLinear::RowParallelLinear(size_t in_features, size_t out_features, bo
} else {
bias_ = Parameter(); // Default constructed empty parameter
}

// SPDLOG_DEBUG("Created RowParallelLinear module: in_features={}, out_features={}, bias={}, dtype={}",
// in_features, out_features, bias, static_cast<int>(dtype_));
}

Tensor RowParallelLinear::forward(Tensor &input) const {
auto output = BaseLinear::forward(input);

if ((tp_size_ > 1) && (communicator_ != nullptr)) {

Size count = output->numel();
DataType type = output->dtype();

infinirtStream_t stream = infinicore::context::getStream();

INFINICORE_CHECK_ERROR(infinicclAllReduce(output->data(), output->data(), count, static_cast<infiniDtype_t>(static_cast<int>(type)),
INFINICCL_SUM, communicator_, stream));
INFINICORE_CHECK_ERROR(infinirtStreamSynchronize(stream));

// RUN_INFINI(infinirtStreamSynchronize(stream));
op::distributed::allreduce_(output, output, INFINICCL_SUM, communicator_);
}
return output;
}