Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/infinicore/ops/bi_attention.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"

namespace infinicore::op {
// Bi-directional (non-causal) attention operator entry point.
// The actual math lives in per-device backends registered in dispatcher().
class BiAttention {
public:
// Dispatch signature: (out, q, k, v, k_cache, v_cache, pos).
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, size_t);
// Runs the backend registered for out's device type, writing into `out`.
static void execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
// Per-device-type registry of backend implementations (see *_infiniop.cc).
static common::OpDispatcher<schema> &dispatcher();
};

// Out-of-place variant: allocates and returns the output tensor.
Tensor bi_attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
// In-place variant: writes the result into a caller-provided `out`.
void bi_attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
} // namespace infinicore::op
1 change: 1 addition & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "infiniop/ops/topkrouter.h"
#include "infiniop/ops/topksoftmax.h"
#include "infiniop/ops/zeros.h"
#include "infiniop/ops/bi_attention.h"
#include "infiniop/tensor_descriptor.h"

#endif // __INFINIOP_API_H__
34 changes: 34 additions & 0 deletions include/infiniop/ops/bi_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef __INFINIOP_BI_ATTENTION_API_H__
#define __INFINIOP_BI_ATTENTION_API_H__

#include "../operator_descriptor.h"
// NOTE(review): gemm.h / swiglu.h appear inherited from the attention.h this
// header was copied from — confirm this API actually needs them.
#include "gemm.h"
#include "swiglu.h"

// Opaque handle for a configured bi-attention operator instance.
typedef struct InfiniopDescriptor *infiniopBiAttentionDescriptor_t;

// Creates a descriptor from the tensor descriptors of all operands plus the
// position offset `pos`. Returns a status code; on success *desc_ptr is set.
__C __export infiniStatus_t infiniopCreateBiAttentionDescriptor(infiniopHandle_t handle,
infiniopBiAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
size_t pos);

// Queries the scratch-buffer size (in bytes) required by infiniopBiAttention.
__C __export infiniStatus_t infiniopGetBiAttentionWorkspaceSize(infiniopBiAttentionDescriptor_t desc, size_t *size);

// Executes bi-attention. q/k/v are read-only; k_cache/v_cache are non-const —
// presumably updated in place with the new keys/values (confirm in the impl).
// `stream` is the backend execution stream (may be NULL for the default).
__C __export infiniStatus_t infiniopBiAttention(infiniopBiAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k,
const void *v,
void *k_cache,
void *v_cache,
void *stream);

// Releases all resources owned by the descriptor.
__C __export infiniStatus_t infiniopDestroyBiAttentionDescriptor(infiniopBiAttentionDescriptor_t desc);
#endif // __INFINIOP_BI_ATTENTION_API_H__
28 changes: 28 additions & 0 deletions python/infinicore/ops/bi_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def bi_attention(q, k, v, k_cache, v_cache, pos, *, out=None):
    """Run the native bi-directional attention op.

    Args:
        q, k, v: Input tensors (``infinicore.Tensor``).
        k_cache, v_cache: Key/value cache tensors; passed to the native op,
            which presumably updates them in place (they are non-const in the
            C API).
        pos: Position offset forwarded to the kernel.
        out: Optional pre-allocated output tensor. When omitted, the native
            op allocates the output and a new ``Tensor`` wrapper is returned.

    Returns:
        ``out`` when provided, otherwise a freshly allocated ``Tensor``.
    """
    if out is None:
        # BUG FIX: this branch previously called ``_infinicore.attention``,
        # which dispatched the plain (causal) attention op instead of the
        # bi-directional one exposed by this module.
        return Tensor(
            _infinicore.bi_attention(
                q._underlying,
                k._underlying,
                v._underlying,
                k_cache._underlying,
                v_cache._underlying,
                pos,
            )
        )

    _infinicore.bi_attention_(
        out._underlying,
        q._underlying,
        k._underlying,
        v._underlying,
        k_cache._underlying,
        v_cache._underlying,
        pos,
    )

    return out
31 changes: 31 additions & 0 deletions src/infinicore/ops/bi_attention/bi_attention.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "infinicore/ops/bi_attention.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Returns the process-wide registry mapping device types to BiAttention
// backend implementations. Function-local static: constructed on first use,
// so backends can self-register from static initializers in any order.
common::OpDispatcher<BiAttention::schema> &BiAttention::dispatcher() {
    static common::OpDispatcher<BiAttention::schema> dispatcher_;
    return dispatcher_;
} // (stray trailing semicolon after the function definition removed)

// Validates device placement, then forwards all operands to the backend
// registered for the output tensor's device type.
void BiAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
    // Every operand must live on the same device before dispatching.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v, k_cache, v_cache);
    const auto device = out->device();
    infinicore::context::setDevice(device);
    auto backend = dispatcher().lookup(device.getType());
    backend(out, q, k, v, k_cache, v_cache, pos);
}

// Out-of-place bi-attention: allocates the output and delegates to
// bi_attention_. The output permutes q's leading axes:
// q is read as {n_q_head, seq_len, head_dim} and out is allocated as
// {seq_len, n_q_head, head_dim}.
// NOTE(review): the axis meanings are inferred from the variable names only —
// confirm this layout against the kernel implementation.
Tensor bi_attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
size_t n_q_head = q->shape()[0];
size_t seq_len = q->shape()[1];
size_t head_dim = q->shape()[2];
Shape shape = {seq_len, n_q_head, head_dim};
// Output matches q's dtype and device.
auto out = Tensor::empty(shape, q->dtype(), q->device());
bi_attention_(out, q, k, v, k_cache, v_cache, pos);
return out;
}

// In-place bi-attention: thin wrapper over BiAttention::execute, writing the
// result into the caller-provided `out`.
void bi_attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
BiAttention::execute(out, q, k, v, k_cache, v_cache, pos);
}

} // namespace infinicore::op
52 changes: 52 additions & 0 deletions src/infinicore/ops/bi_attention/bi_attention_infiniop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/bi_attention.hpp"
#include "infinicore/ops/common/cache.hpp"
#include <infiniop.h>

namespace infinicore::op::bi_attention_impl::infiniop {

// Per-thread LRU-style cache of infiniop descriptors, keyed by a hash of the
// call's tensor metadata (see calculate()). thread_local avoids locking; the
// eviction callback destroys the native descriptor.
thread_local common::OpCache<size_t, infiniopBiAttentionDescriptor_t> caches(
100, // capacity
[](infiniopBiAttentionDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyBiAttentionDescriptor(desc));
desc = nullptr;
}
});

// infiniop backend for BiAttention: creates (or re-uses) a native descriptor
// for this exact operand configuration, then launches the op on the current
// context stream.
void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
// Cache key covers every operand plus pos, so any change in tensor metadata
// (or position) yields a fresh descriptor.
size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos);

auto device = context::getDevice();
auto &cache = caches.getCache(device);

auto desc_opt = cache.get(seed);
infiniopBiAttentionDescriptor_t desc = nullptr;

if (!desc_opt) {
// Miss: build a descriptor from the operand descriptors and memoize it.
INFINICORE_CHECK_ERROR(infiniopCreateBiAttentionDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

// Workspace is (re)allocated on every call; the allocator presumably pools
// memory — TODO confirm, otherwise this is a per-call allocation cost.
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetBiAttentionWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopBiAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k->data(), v->data(),
k_cache->data(), v_cache->data(), context::getStream()));
}

// Self-registration at load time: installs this infiniop backend for all
// device types. NOTE(review): the `false` flag presumably means "do not
// overwrite existing registrations" — confirm against OpDispatcher::registerAll.
static bool registered = []() {
BiAttention::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::bi_attention_impl::infiniop
Empty file.
6 changes: 3 additions & 3 deletions src/infiniop/ops/attention/attention.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#ifndef ATTENTION_H
#define ATTENTION_H
#ifndef BI_ATTENTION_H
#define BI_ATTENTION_H

#include "../../operator.h"
#include "info.h"

#define DESCRIPTOR(NAMESPACE) \
\
namespace op::attention::NAMESPACE { \
namespace op::bi_attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
Expand Down
37 changes: 37 additions & 0 deletions src/infiniop/ops/bi_attention/bi_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Internal descriptor declaration for the bi_attention op.
// BUG FIX: this new file kept the `ATTENTION_H` guard and `op::attention`
// namespace copied from attention.h; both collide with the original attention
// op. Guard renamed to a unique symbol and namespace fixed to op::bi_attention.
#ifndef __INFINIOP_BI_ATTENTION_H__
#define __INFINIOP_BI_ATTENTION_H__

#include "../../operator.h"
#include "info.h"

// Declares op::bi_attention::<NAMESPACE>::Descriptor for a backend namespace.
// NOTE(review): create() still takes the (y_desc, x_desc) parameter list of
// the op this file was copied from, while the public C API
// (infiniopCreateBiAttentionDescriptor) passes out/q/k/v/k_cache/v_cache
// descriptors plus pos — confirm and align the signatures.
#define DESCRIPTOR(NAMESPACE)                                 \
                                                              \
    namespace op::bi_attention::NAMESPACE {                   \
    class Descriptor final : public InfiniopDescriptor {      \
        struct Opaque;                                        \
        Opaque *_opaque;                                      \
        size_t _workspace_size;                               \
                                                              \
        Descriptor(                                           \
            Opaque *opaque,                                   \
            size_t workspace_size,                            \
            infiniDevice_t device_type,                       \
            int device_id)                                    \
            : InfiniopDescriptor{device_type, device_id},     \
              _opaque(opaque),                                \
              _workspace_size(workspace_size) {}              \
                                                              \
    public:                                                   \
        ~Descriptor();                                        \
                                                              \
        size_t workspaceSize() const { return _workspace_size; } \
                                                              \
        static infiniStatus_t create(                         \
            infiniopHandle_t handle,                          \
            Descriptor **desc_ptr,                            \
            infiniopTensorDescriptor_t y_desc,                \
            infiniopTensorDescriptor_t x_desc);               \
    };                                                        \
    }

#endif // __INFINIOP_BI_ATTENTION_H__
Loading