diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index a7249ec9d..a156a8176 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -4,12 +4,15 @@ #include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" +#include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" #include "ops/paged_caching.hpp" #include "ops/random_sample.hpp" +#include "ops/random_sample_batched.hpp" #include "ops/rearrange.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp new file mode 100644 index 000000000..24e33cfb6 --- /dev/null +++ b/include/infinicore/ops/flash_attention.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool); + +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); +} // namespace infinicore::op diff --git a/include/infinicore/ops/kv_caching.hpp b/include/infinicore/ops/kv_caching.hpp new file mode 100644 index 000000000..3a70c2824 --- /dev/null +++ b/include/infinicore/ops/kv_caching.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "../graph/graph.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_CLASS(KVCaching, Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &); + +void kv_caching_(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths); +} // namespace infinicore::op diff --git a/include/infinicore/ops/random_sample_batched.hpp b/include/infinicore/ops/random_sample_batched.hpp new file mode 100644 index 000000000..8906bc12b --- /dev/null +++ b/include/infinicore/ops/random_sample_batched.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class RandomSampleBatched { +public: + using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, const float *, int); + static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); + static common::OpDispatcher &dispatcher(); +}; + +// Out-of-place API +Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); +// In-place API +void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); + +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index c0a09fcb4..246180e65 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -9,8 +9,10 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include "infiniop/ops/flash_attention.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" +#include "infiniop/ops/kv_caching.h" #include "infiniop/ops/layer_norm.h" #include "infiniop/ops/logsoftmax.h" #include "infiniop/ops/lp_norm.h" @@ -20,6 +22,7 @@ #include "infiniop/ops/paged_attention_prefill.h" #include "infiniop/ops/paged_caching.h" #include "infiniop/ops/random_sample.h" +#include "infiniop/ops/random_sample_batched.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h new file mode 100644 index 000000000..5ea71335b --- /dev/null +++ b/include/infiniop/ops/flash_attention.h @@ -0,0 +1,36 @@ +#ifndef __INFINIOP_FLASH_ATTENTION_API_H__ +#define __INFINIOP_FLASH_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + float scale, + char is_causal); + +__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc); +#endif diff --git a/include/infiniop/ops/kv_caching.h b/include/infiniop/ops/kv_caching.h new file mode 100644 index 000000000..e6efa48b3 --- /dev/null +++ b/include/infiniop/ops/kv_caching.h @@ -0,0 +1,31 @@ +#ifndef __INFINIOP_KV_CACHING_API_H__ +#define __INFINIOP_KV_CACHING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopKVCachingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths); + +__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream); + +__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/random_sample.h b/include/infiniop/ops/random_sample.h index 1c242d7ba..bb2b15959 100644 --- a/include/infiniop/ops/random_sample.h +++ b/include/infiniop/ops/random_sample.h @@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize( infiniopRandomSampleDescriptor_t desc, size_t *size); -__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor( - infiniopHandle_t handle, - infiniopRandomSampleDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t result, - infiniopTensorDescriptor_t probs); - __C __export infiniStatus_t infiniopRandomSample( infiniopRandomSampleDescriptor_t desc, void *workspace, diff --git a/include/infiniop/ops/random_sample_batched.h b/include/infiniop/ops/random_sample_batched.h new file mode 100644 index 000000000..4512e7dcb --- /dev/null +++ b/include/infiniop/ops/random_sample_batched.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__ +#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t; + +__C __export infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor( + infiniopHandle_t handle, + infiniopRandomSampleBatchedDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs); + +__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize( + infiniopRandomSampleBatchedDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopRandomSampleBatched( + infiniopRandomSampleBatchedDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size, + void *stream); + +__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor( + infiniopRandomSampleBatchedDescriptor_t desc); + +#endif diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index c6b01d5aa..845bbcc0a 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,7 @@ from infinicore.ops.add import add from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ from infinicore.ops.attention import attention +from infinicore.ops.kv_caching import kv_caching from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -115,6 +116,7 @@ "add_rms_norm", "add_rms_norm_", "attention", + "kv_caching", "matmul", "mul", "narrow", diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..d34490365 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,20 +1,24 @@ from .causal_softmax import causal_softmax from .embedding import embedding +from .flash_attention import flash_attention from .linear import linear from .random_sample import random_sample from .rms_norm import rms_norm from .rope import RopeAlgo, rope +from .scaled_dot_product_attention import scaled_dot_product_attention from .silu import silu from .swiglu import swiglu __all__ = [ "causal_softmax", + "embedding", + "flash_attention", + "linear", "random_sample", "rms_norm", + "rope", + "scaled_dot_product_attention", "silu", "swiglu", - "linear", - "embedding", - "rope", "RopeAlgo", ] diff --git a/python/infinicore/nn/functional/flash_attention.py b/python/infinicore/nn/functional/flash_attention.py new file mode 100644 index 000000000..8f42e865f --- /dev/null +++ b/python/infinicore/nn/functional/flash_attention.py @@ -0,0 +1,34 @@ +import math + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def flash_attention( + query, + key, + value, + total_kv_len, + attn_mask=None, + dropout_p=0, + is_causal=False, + scale=None, + enable_gqa=False, +): + assert attn_mask is None and dropout_p == 0 and not enable_gqa + + emb_dim = query.shape[-1] + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + return Tensor( + _infinicore.flash_attention( + query._underlying, + key._underlying, + value._underlying, + total_kv_len._underlying, + scale, + is_causal, + ) + ) diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py new file mode 100644 index 000000000..0b780e562 --- /dev/null +++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py @@ -0,0 +1,35 @@ +import math + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0, + is_causal=False, + scale=None, + enable_gqa=False, +): + raise NotImplementedError("Scaled Dot Product Attention is not yet supported.") + + assert attn_mask is None and dropout_p == 0 and not enable_gqa + + emb_dim = query.shape[-1] + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + return Tensor( + _infinicore.flash_attention( + query._underlying, + key._underlying, + value._underlying, + key.shape[-2], + scale, + is_causal, + ) + ) diff --git a/python/infinicore/ops/kv_caching.py b/python/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..b34f2346e --- /dev/null +++ b/python/infinicore/ops/kv_caching.py @@ -0,0 +1,13 @@ +from infinicore.lib import _infinicore + + +def kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + _infinicore.kv_caching_( + k_cache._underlying, + v_cache._underlying, + k._underlying, + v._underlying, + past_kv_lengths._underlying, + ) + + return k_cache, v_cache diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py index 1499b6bf8..601249615 100644 --- a/scripts/build_ntops.py +++ b/scripts/build_ntops.py @@ -1,3 +1,4 @@ +import concurrent.futures import importlib import pathlib @@ -11,16 +12,32 @@ def _find_and_build_ops(): ops_path = SRC_DIR_PATH / "infiniop" / "ops" - for op_dir in ops_path.iterdir(): - ninetoothed_path = op_dir / "ninetoothed" + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - if ninetoothed_path.is_dir(): - module_path = ninetoothed_path / "build" - relative_path = module_path.relative_to(SRC_DIR_PATH) - import_name = ".".join(relative_path.parts) - module = importlib.import_module(import_name) + for op_dir in ops_path.iterdir(): + ninetoothed_path = op_dir / "ninetoothed" - module.build() + if not ninetoothed_path.is_dir(): + continue + + build_file = ninetoothed_path / "build.py" + if not build_file.exists(): + continue + + futures.append(executor.submit(_build, ninetoothed_path)) + + for future in concurrent.futures.as_completed(futures): + future.result() + + +def _build(ninetoothed_path): + module_path = ninetoothed_path / "build" + relative_path = module_path.relative_to(SRC_DIR_PATH) + import_name = ".".join(relative_path.parts) + module = importlib.import_module(import_name) + + module.build() if __name__ == "__main__": diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc new file mode 100644 index 000000000..21cd56010 --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention.cc @@ -0,0 +1,31 @@ +#include "infinicore/ops/flash_attention.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention); + +FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, total_kv_len, scale, is_causal); +} + +void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal); +} + +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + Shape shape = q->shape(); + int idx = shape.size() - 1; + shape[idx] = v->shape()[idx]; + auto out = Tensor::empty(shape, q->dtype(), q->device()); + flash_attention_(out, q, k, v, total_kv_len, scale, is_causal); + return out; +} + +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc new file mode 100644 index 000000000..f5207f0ee --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc @@ -0,0 +1,55 @@ +#include "../../utils.hpp" +#include "../infiniop_impl.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/flash_attention.hpp" +#include + +namespace infinicore::op::flash_attention_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k, v, total_kv_len; + float scale; + bool is_causal; +}; + +void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { + size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, FlashAttention, + seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal); + + INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(total_kv_len), scale, is_causal}; + + return planned; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopFlashAttention( + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup); + +} // namespace infinicore::op::flash_attention_impl::infiniop diff --git a/src/infinicore/ops/kv_caching/kv_caching.cc b/src/infinicore/ops/kv_caching/kv_caching.cc new file mode 100644 index 000000000..0110f7973 --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching.cc @@ -0,0 +1,42 @@ +#include "infinicore/ops/kv_caching.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching); + +KVCaching::KVCaching(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths); + INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(), + k_cache, + v_cache, + k, + v, + past_kv_lengths); +} + +void KVCaching::execute(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching, + k_cache, + v_cache, + k, + v, + past_kv_lengths); +} + +void kv_caching_(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc new file mode 100644 index 000000000..53ea5f0ae --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc @@ -0,0 +1,60 @@ +#include "../infiniop_impl.hpp" +#include "infinicore/ops/kv_caching.hpp" + +namespace infinicore::op::kv_caching_impl::infiniop { + +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths; +}; + +void *plan(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, KVCaching, + seed, k_cache->desc(), v_cache->desc(), + k->desc(), v->desc(), past_kv_lengths->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(past_kv_lengths)}; + + return planned; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); + + INFINICORE_CHECK_ERROR(infiniopKVCaching( + planned->descriptor->desc, + nullptr, 0, + planned->k_cache->data(), + planned->v_cache->data(), + planned->k->data(), + planned->v->data(), + planned->past_kv_lengths->data(), + context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup); + +} // namespace infinicore::op::kv_caching_impl::infiniop diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc new file mode 100644 index 000000000..a02635f66 --- /dev/null +++ b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc @@ -0,0 +1,54 @@ +#include "infinicore/ops/random_sample_batched.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &RandomSampleBatched::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void RandomSampleBatched::execute( + Tensor result, + Tensor probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, probs); + infinicore::context::setDevice(result->device()); + auto device_type = result->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No RandomSampleBatched implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(result, probs, random_val, topp, topk, temperature, batch_size); +} + +Tensor random_sample_batched( + Tensor logits, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + Shape shape = logits->shape(); + auto result = Tensor::empty(shape, DataType::I32, logits->device()); + random_sample_batched_(result, logits, random_val, topp, topk, temperature, batch_size); + return result; +} +void random_sample_batched_( + Tensor result, + Tensor logits, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + RandomSampleBatched::execute(result, logits, random_val, topp, topk, temperature, batch_size); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc new file mode 100644 index 000000000..2916c0b2a --- /dev/null +++ b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc @@ -0,0 +1,63 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/random_sample_batched.hpp" +#include + +namespace infinicore::op::random_sample_batched_impl::infiniop_backend { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopRandomSampleBatchedDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleBatchedDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate( + Tensor result, + Tensor probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + size_t seed = hash_combine(result, probs, batch_size); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopRandomSampleBatchedDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleBatchedDescriptor( + context::getInfiniopHandle(device), &desc, + result->desc(), probs->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetRandomSampleBatchedWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopRandomSampleBatched( + desc, + workspace->data(), workspace_size, + result->data(), probs->data(), + random_val, topp, topk, temperature, + batch_size, + context::getStream())); +} + +} // namespace infinicore::op::random_sample_batched_impl::infiniop_backend + +namespace infinicore::op { +static bool registered = []() { + RandomSampleBatched::dispatcher().registerAll(&random_sample_batched_impl::infiniop_backend::calculate, false); + return true; +}(); +} // namespace infinicore::op diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 3d6ebe79a..c3b781050 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -7,6 +7,8 @@ #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" +#include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -29,19 +31,21 @@ inline void bind(py::module &m) { bind_add_rms_norm(m); bind_attention(m); bind_causal_softmax(m); - bind_random_sample(m); + bind_embedding(m); + bind_flash_attention(m); + bind_kv_caching(m); bind_linear(m); bind_matmul(m); bind_mul(m); bind_paged_attention(m); bind_paged_attention_prefill(m); bind_paged_caching(m); + bind_random_sample(m); bind_rearrange(m); bind_rms_norm(m); + bind_rope(m); bind_silu(m); bind_swiglu(m); - bind_rope(m); - bind_embedding(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/flash_attention.hpp b/src/infinicore/pybind11/ops/flash_attention.hpp new file mode 100644 index 000000000..6e3766796 --- /dev/null +++ b/src/infinicore/pybind11/ops/flash_attention.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include + +#include "infinicore/ops/flash_attention.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_flash_attention(py::module &m) { + m.def("flash_attention", + &op::flash_attention, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("total_kv_len"), + py::arg("scale"), + py::arg("is_causal")); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/kv_caching.hpp b/src/infinicore/pybind11/ops/kv_caching.hpp new file mode 100644 index 000000000..2864312b2 --- /dev/null +++ b/src/infinicore/pybind11/ops/kv_caching.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include "infinicore/ops/kv_caching.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_kv_caching(py::module &m) { + m.def("kv_caching_", + &op::kv_caching_, + py::arg("k_cache"), + py::arg("v_cache"), + py::arg("k"), + py::arg("v"), + py::arg("past_kv_lengths"), + R"doc(In-place Key-Value Caching. + +Updates the KV cache in-place with new key and value tensors. + +Args: + k_cache: Key cache tensor to update in-place + v_cache: Value cache tensor to update in-place + k: New key tensor to append + v: New value tensor to append + past_kv_lengths: Tensor containing current sequence lengths for each batch +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ninetoothed/build.py b/src/infiniop/ninetoothed/build.py index aea421b7f..153e6b9f5 100644 --- a/src/infiniop/ninetoothed/build.py +++ b/src/infiniop/ninetoothed/build.py @@ -1,3 +1,4 @@ +import concurrent.futures import functools import inspect import itertools @@ -16,40 +17,28 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): headers = [] all_param_names = [] + combinations = [] launches = [] - for combination in _generate_param_value_combinations(constexpr_param_grid): - arrangement, application, tensors = premake(**combination) + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - for param_name, param_value in combination.items(): - if isinstance(param_value, str): - combination[param_name] = ( - f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" - ) + for combination in tuple( + _generate_param_value_combinations(constexpr_param_grid) + ): + future = executor.submit( + _make, premake, combination, caller, op_name, output_dir + ) - combination = {f"{name}_": value for name, value in combination.items()} + futures.append(future) - kernel_name = f"{op_name}_{_generate_suffix(combination.values())}" + for future in concurrent.futures.as_completed(futures): + header, param_names, combination, launch = future.result() - ninetoothed.make( - arrangement, - application, - tensors, - caller=caller, - kernel_name=kernel_name, - output_dir=output_dir, - ) - - header = output_dir / f"{kernel_name}.h" - param_names = ("stream",) + tuple( - inspect.signature(application).parameters.keys() - ) - launch = f""" if ({_generate_condition(combination)}) - return launch_{kernel_name}({", ".join(param_names)});""" - - headers.append(header) - all_param_names.append(param_names) - launches.append(launch) + headers.append(header) + all_param_names.append(param_names) + combinations.append(combination) + launches.append(launch) includes = "\n".join(f'#include "{header}"' for header in headers) @@ -64,7 +53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): "NineToothedStream", ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)] - for param_name in combination: + for param_name in functools.reduce(lambda x, y: x | y, combinations, {}): param_names.append(param_name) param_types.append("int") @@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content) +def _make(premake, combination, caller, op_name, output_dir): + arrangement, application, tensors = premake(**combination) + + for param_name, param_value in combination.items(): + if isinstance(param_value, str): + combination[param_name] = ( + f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" + ) + + combination = {f"{name}_": value for name, value in combination.items()} + + kernel_name = f"{op_name}_{_generate_suffix(combination.values())}" + + ninetoothed.make( + arrangement, + application, + tensors, + caller=caller, + kernel_name=kernel_name, + output_dir=output_dir, + ) + + header = output_dir / f"{kernel_name}.h" + param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys()) + launch = f""" if ({_generate_condition(combination)}) + return launch_{kernel_name}({", ".join(param_names)});""" + + return header, param_names, combination, launch + + def _generate_condition(combination): return " && ".join(f"{param} == {value}" for param, value in combination.items()) diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h new file mode 100644 index 000000000..1b7d1fe3a --- /dev/null +++ b/src/infiniop/ninetoothed/utils.h @@ -0,0 +1,75 @@ +#ifndef __NINETOOTHED_UTILS__ +#define __NINETOOTHED_UTILS__ + +#include +#include +#include +#include + +namespace ninetoothed { + +template +class Tensor { +public: + using Data = decltype(NineToothedTensor::data); + + using Size = std::remove_pointer_t; + + using Stride = std::remove_pointer_t; + + template + Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {} + + Tensor(const void *data, std::initializer_list shape, std::initializer_list strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} + + Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {} + + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} + + operator NineToothedTensor() { return {const_cast(data_), shape_.data(), strides_.data()}; } + + template + Tensor expand(const Shape &sizes) const { + auto new_ndim{sizes.size()}; + + decltype(shape_) shape(new_ndim, 1); + decltype(strides_) strides(new_ndim, 0); + + auto num_new_dims{new_ndim - ndim_}; + + for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) { + shape[dim + num_new_dims] = shape_[dim]; + strides[dim + num_new_dims] = strides_[dim]; + } + + for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) { + if (sizes[dim] == std::numeric_limits>::max() || shape[dim] != 1) { + continue; + } + + shape[dim] = sizes[dim]; + strides[dim] = 0; + } + + return {data_, shape, strides}; + } + + Tensor expand_as(const Tensor &other) const { + return expand(other.shape_); + } + +private: + const void *data_{nullptr}; + + std::vector shape_; + + std::vector strides_; + + Size ndim_{0}; + + T value_{0}; +}; + +} // namespace ninetoothed + +#endif diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py new file mode 100644 index 000000000..4455e1ea6 --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -0,0 +1,35 @@ +import ninetoothed +from . import flash_attention +from .flash_attention import CausalVariant + +import infiniop.ninetoothed.build + + +def build(): + with_kv_cache_values = (0,) + emb_dim_values = (16, 32, 64, 128, 256) + is_causal_values = (0, 1) + with_attn_mask_values = (0,) + causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT) + dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32) + block_size_m_values = (256,) + block_size_n_values = (64,) + + constexpr_param_grid = { + "with_kv_cache": with_kv_cache_values, + "emb_dim": emb_dim_values, + "is_causal": is_causal_values, + "with_attn_mask": with_attn_mask_values, + "causal_variant": causal_variant_values, + "dtype": dtype_values, + "block_size_m": block_size_m_values, + "block_size_n": block_size_n_values, + } + + infiniop.ninetoothed.build.build( + flash_attention.premake, + constexpr_param_grid, + caller="cuda", + op_name="flash_attention", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h new file mode 100644 index 000000000..0a6e9c1f8 --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -0,0 +1,147 @@ +#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__ +#define __FLASH_ATTENTION_DESCRIPTOR_H__ + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/flash_attention.h" +#include "../../../ninetoothed/utils.h" + +namespace op::flash_attention::ninetoothed { + +class Descriptor final : public InfiniopDescriptor { +public: + Descriptor(infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + double scale, + char is_causal) : InfiniopDescriptor{handle->device, handle->device_id}, + _query_shape{q_desc->shape()}, + _query_strides{q_desc->strides()}, + _key_shape{k_desc->shape()}, + _key_strides{k_desc->strides()}, + _value_shape{v_desc->shape()}, + _value_strides{v_desc->strides()}, + _total_kv_shape{total_kv_len->shape()}, + _total_kv_strides{total_kv_len->strides()}, + _output_strides{out_desc->strides()}, + _dtype{q_desc->dtype()}, + _scale{scale}, + _is_causal{is_causal} { + } + + ~Descriptor() = default; + + size_t get_workspace_size() const { + return 0; + } + + infiniStatus_t calculate(void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream) const { + uint64_t empty_shape[4]; + int64_t empty_strides[4]; + + auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}}; + auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}}; + auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}}; + auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}}; + + NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides}; + NineToothedTensor is_causal; + NineToothedTensor scale{const_cast(&_scale), nullptr, nullptr}; + auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}}; + NineToothedTensor with_attn_mask; + NineToothedTensor causal_variant; + + const auto with_kv_cache_{0}; + const auto emb_dim_{_query_shape[3]}; + const auto is_causal_{_is_causal}; + const auto with_attn_mask_{0}; + const auto causal_variant_{2}; + const auto dtype_{_dtype}; + + constexpr auto block_size_m_{256}; + constexpr auto block_size_n_{64}; + + if (launch_flash_attention(stream, + query, + key, + value, + total_kv_length, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache_, + emb_dim_, + is_causal_, + with_attn_mask_, + causal_variant_, + dtype_, + block_size_m_, + block_size_n_)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + + static infiniStatus_t create(infiniopHandle_t handle, + Descriptor **desc, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + double scale, + char is_causal) { + *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal}; + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector _query_shape; + + std::vector _query_strides; + + std::vector _key_shape; + + std::vector _key_strides; + + std::vector _value_shape; + + std::vector _value_strides; + + std::vector _total_kv_shape; + + std::vector _total_kv_strides; + + std::vector _output_strides; + + infiniDtype_t _dtype; + + double _scale; + + char _is_causal; +}; + +} // namespace op::flash_attention::ninetoothed + +#endif // __FLASH_ATTENTION_DESCRIPTOR_H__ diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py new file mode 100644 index 000000000..22d63ae4a --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py @@ -0,0 +1,281 @@ +import enum +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +BLOCK_SIZE_M = ninetoothed.block_size() +BLOCK_SIZE_N = ninetoothed.block_size() + + +class CausalVariant(enum.IntEnum): + """Please refer to ``_.""" + + UPPER_LEFT = enum.auto() + + LOWER_RIGHT = enum.auto() + + +def arrangement( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache, + block_size_m=None, + block_size_n=None, +): + def arrange_query_or_output(input): + arranged = input.tile((1, 1, block_size_m, -1)).tile( + (1, query.shape[-3] // key.shape[-3], 1, 1) + ) + arranged.dtype = arranged.dtype.squeeze((0, 2, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_key_or_value(input): + arranged = ( + input.tile((1, 1, block_size_n, -1)) + .tile((1, 1, -1, -1)) + .expand((-1, -1, query_arranged.shape[-2], -1)) + ) + arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_total_kv_len(input, shape): + arranged = input.tile((1,)) + arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape) + return arranged + + def arrange_present_key_or_present_value(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + def arrange_attn_mask(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1, 2)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + if block_size_m is None: + block_size_m = BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = BLOCK_SIZE_N + + query_arranged = arrange_query_or_output(query) + key_arranged = arrange_key_or_value(key) + value_arranged = arrange_key_or_value(value) + total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape) + present_key_arranged = arrange_present_key_or_present_value(present_key) + present_value_arranged = arrange_present_key_or_present_value(present_value) + present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot) + present_value_slot_arranged = arrange_present_key_or_present_value( + present_value_slot + ) + attn_mask_arranged = arrange_attn_mask(attn_mask) + is_causal_arranged = is_causal + scale_arranged = scale + output_arranged = arrange_query_or_output(output) + with_attn_mask_arranged = with_attn_mask + causal_variant_arranged = causal_variant + + if with_kv_cache: + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + present_key_arranged, + present_value_arranged, + present_key_slot_arranged, + present_value_slot_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + +def application_with_kv_cache( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + present_key_slot = present_key # noqa: F841 + present_value_slot = present_value # noqa: F841 + + application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + +def application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + actual_kv_len = total_kv_len[0] + + for i in range(query.shape[0]): + query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype) + + acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32) + lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) + max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) + + for j in range(-(-actual_kv_len // key.dtype.shape[0])): + + qk = ntl.dot(query_i, ntl.trans(key[j])) + + key_pos = key[j].offsets(-2) + qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf")) + + if with_attn_mask: + qk += attn_mask[j] + + if is_causal: + query_pos = query[i].offsets(-2) + + if causal_variant == 2: # CausalVariant.LOWER_RIGHT: + mask = ( + query_pos[:, None] + actual_kv_len - query.source.shape[-2] + >= key_pos[None, :] + ) + else: + mask = query_pos[:, None] >= key_pos[None, :] + + qk = ntl.where(mask, qk, float("-inf")) + + next_max = ntl.maximum(max, ntl.max(qk, 1)) + stable_qk = ntl.exp2(qk - next_max[:, None]) + + alpha = ntl.exp2(max - next_max) + acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j]) + max = next_max + lse = lse * alpha + ntl.sum(stable_qk, 1) + + acc /= lse[:, None] + output[i] = acc # noqa: F841 + + +def premake( + with_kv_cache, + emb_dim=None, + is_causal=None, + with_attn_mask=None, + causal_variant=None, + dtype=None, + block_size_m=None, + block_size_n=None, +): + arrangement_ = functools.partial( + arrangement, + with_kv_cache=with_kv_cache, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + query, key, value, attn_mask, output = ( + Tensor( + 4, + dtype=dtype, + shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}), + ) + for _ in range(5) + ) + total_kv_len = Tensor(1, dtype=ninetoothed.int32) + present_key, present_value, present_key_slot, present_value_slot = ( + Tensor(4, dtype=dtype) for _ in range(4) + ) + scale = Tensor(0, dtype=ninetoothed.float64) + is_causal = Tensor(0, constexpr=True, value=is_causal) + with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask) + causal_variant = Tensor(0, constexpr=True, value=causal_variant) + + if emb_dim is not None: + for tensor in (query, key, value, attn_mask, output): + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + if with_kv_cache: + application = application_with_kv_cache + else: + application = application_without_kv_cache + + tensors = ( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc new file mode 100644 index 000000000..ddccf9836 --- /dev/null +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -0,0 +1,124 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/flash_attention.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/flash_attention_cpu.h" +#endif +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) +#include "ninetoothed/descriptor.h" +#endif +#endif + +__C infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t total_kv_len, + float scale, + char is_causal) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::flash_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + q_desc, \ + k_desc, \ + v_desc, \ + total_kv_len, \ + scale, \ + is_causal); + + switch (handle->device) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + const void *total_kv_len, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream); + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/src/infiniop/ops/kv_caching/ninetoothed/build.py b/src/infiniop/ops/kv_caching/ninetoothed/build.py new file mode 100644 index 000000000..03481c86b --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/build.py @@ -0,0 +1,27 @@ +import ninetoothed +from . import kv_caching + +import infiniop.ninetoothed.build + + +def build(): + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "emb_dim": (1, 16, 32, 64, 128, 256), + "dtype": dtype_values, + "block_size_m": (64,), + "block_size_n": (64,), + } + + infiniop.ninetoothed.build.build( + kv_caching.premake, + constexpr_param_grid, + caller="cuda", + op_name="kv_caching", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h new file mode 100644 index 000000000..43388f58d --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h @@ -0,0 +1,101 @@ +#ifndef KV_CACHING_H +#define KV_CACHING_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/kv_caching.h" +#include "../../../ninetoothed/utils.h" + +namespace op::kv_caching::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id}, + k_cache_shape_{k_cache_desc->shape()}, + k_cache_strides_{k_cache_desc->strides()}, + v_cache_shape_{v_cache_desc->shape()}, + v_cache_strides_{v_cache_desc->strides()}, + k_shape_{k_desc->shape()}, + k_strides_{k_desc->strides()}, + v_shape_{v_desc->shape()}, + v_strides_{v_desc->strides()}, + past_kv_lengths_shape_{past_kv_lengths_desc->shape()}, + past_kv_lengths_strides_{past_kv_lengths_desc->strides()}, + dtype_{k_desc->dtype()} {} + + ~Descriptor() = default; + + size_t get_workspace_size() const { return 0; }; + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths}; + return INFINI_STATUS_SUCCESS; + } + + infiniStatus_t calculate( + void *workspace, size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) const { + auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}}; + auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}}; + auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}}; + auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}}; + auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}}; + + if (launch_kv_caching(stream, + k_cache_nt, + v_cache_nt, + k_nt, + v_nt, + past_kv_lengths_nt, + k_shape_[3], + dtype_, + 64, 64)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector k_cache_shape_; + std::vector k_cache_strides_; + + std::vector v_cache_shape_; + std::vector v_cache_strides_; + + std::vector k_shape_; + std::vector k_strides_; + std::vector v_shape_; + std::vector v_strides_; + + std::vector past_kv_lengths_shape_; + std::vector past_kv_lengths_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::kv_caching::ninetoothed + +#endif // KV_CACHING_H diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py new file mode 100644 index 000000000..dfc5088e9 --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py @@ -0,0 +1,66 @@ +import functools +import ninetoothed +from ninetoothed import Tensor + + +def arrangement( + k_cache, + v_cache, + k, + v, + past_lengths, + block_size_m=ninetoothed.block_size(), + block_size_n=ninetoothed.block_size(), +): + k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + past_lengths_arranged = ( + past_lengths.tile((1,)) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .expand((-1, *k_arranged.shape)) + ) + + return ( + k_cache_arranged, + v_cache_arranged, + k_arranged, + v_arranged, + past_lengths_arranged, + ) + + +def application(k_cache, v_cache, k, v, past_lengths): + pos = past_lengths + + for i in range(k.shape[-2]): + k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0] + v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0] + + +def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None): + arrangement_ = functools.partial( + arrangement, block_size_m=block_size_m, block_size_n=block_size_n + ) + + shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256}) + + tensors = ( + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(1, dtype=ninetoothed.int64), + ) + + if emb_dim is not None: + for tensor in tensors: + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc new file mode 100644 index 000000000..34bdf9a99 --- /dev/null +++ b/src/infiniop/ops/kv_caching/operator.cc @@ -0,0 +1,143 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/kv_caching.h" + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API) +#include "ninetoothed/kv_caching.h" +#endif +#endif + +__C infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::kv_caching::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + k_cache, \ + v_cache, \ + k, \ + v, \ + past_kv_lengths) + + switch (handle->device) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetKVCachingWorkspaceSize( + infiniopKVCachingDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + GET_SIZE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopKVCaching( + infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream) + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyKVCachingDescriptor( + infiniopKVCachingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + DELETE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} diff --git a/src/infiniop/ops/random_sample_batched/operator.cc b/src/infiniop/ops/random_sample_batched/operator.cc new file mode 100644 index 000000000..d0047ad53 --- /dev/null +++ b/src/infiniop/ops/random_sample_batched/operator.cc @@ -0,0 +1,128 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/random_sample_batched.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/random_sample_batched_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +// #include "nvidia/random_sample_batched_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor( + infiniopHandle_t handle, + infiniopRandomSampleBatchedDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::random_sample::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + result, \ + probs) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize( + infiniopRandomSampleBatchedDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + using Ptr = const op::random_sample::NAMESPACE::Descriptor *; \ + *size = reinterpret_cast(desc)->minWorkspaceSize(); \ + } \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // GET_SIZE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopRandomSampleBatched( + infiniopRandomSampleBatchedDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, \ + result, probs, \ + random_val, \ + topp, topk, temperature, \ + batch_size, \ + stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor( + infiniopRandomSampleBatchedDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} diff --git a/src/infiniop/ops/relu/metax/relu_metax.maca b/src/infiniop/ops/relu/metax/relu_metax.maca index 900fce9e0..2c5104bdd 100644 --- a/src/infiniop/ops/relu/metax/relu_metax.maca +++ b/src/infiniop/ops/relu/metax/relu_metax.maca @@ -2,6 +2,7 @@ #include "../../../../../build/ninetoothed/relu.h" #include "../../../devices/metax/metax_common.h" +#include "../../../ninetoothed/utils.h" #include "relu_metax.h" namespace op::relu::metax { @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( } const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { diff --git a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu index 22b85a401..a3c79fb52 100644 --- a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu +++ b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu @@ -1,5 +1,6 @@ #ifdef ENABLE_NINETOOTHED #include "../../../../../build/ninetoothed/relu.h" +#include "../../../ninetoothed/utils.h" #endif #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia.cuh" @@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( } #ifdef ENABLE_NINETOOTHED const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { diff --git a/src/infiniop/ops/swiglu/ninetoothed/build.py b/src/infiniop/ops/swiglu/ninetoothed/build.py new file mode 100644 index 000000000..fa4af6db2 --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/build.py @@ -0,0 +1,29 @@ +import ninetoothed +from . import swiglu + +import infiniop.ninetoothed.build + + +def build(): + MAX_NDIM = 5 + + ndim_values = range(1, MAX_NDIM + 1) + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "ndim": ndim_values, + "dtype": dtype_values, + "block_size": (1024,), + } + + infiniop.ninetoothed.build.build( + swiglu.premake, + constexpr_param_grid, + caller="cuda", + op_name="swiglu", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.h b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h new file mode 100644 index 000000000..4aa2fa70e --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h @@ -0,0 +1,82 @@ +#ifndef SWIGLU_H +#define SWIGLU_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/swiglu.h" +#include "../../../ninetoothed/utils.h" + +namespace op::swiglu::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id}, + out_shape_{out_desc->shape()}, + out_strides_{out_desc->strides()}, + up_shape_{input_desc_vec[0]->shape()}, + up_strides_{input_desc_vec[0]->strides()}, + gate_shape_{input_desc_vec[1]->shape()}, + gate_strides_{input_desc_vec[1]->strides()}, + dtype_{out_desc->dtype()} {} + + ~Descriptor() = default; + + size_t workspaceSize() const { + return 0; + } + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + *desc_ptr = new Descriptor(handle, out_desc, input_desc_vec); + return INFINI_STATUS_SUCCESS; + } + + infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)}; + auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)}; + auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)}; + + if (launch_swiglu(stream, + out_nt, + up_nt, + gate_nt, + out_shape_.size(), + dtype_, + 1024)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector out_shape_; + std::vector out_strides_; + + std::vector up_shape_; + std::vector up_strides_; + + std::vector gate_shape_; + std::vector gate_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::swiglu::ninetoothed + +#endif // SWIGLU_H diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.py b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py new file mode 100644 index 000000000..62074a84b --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(output, up, gate): + output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 9d8e6406a..b3fabba32 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -6,14 +6,22 @@ #include "cpu/swiglu_cpu.h" #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "nvidia/swiglu_nvidia.cuh" #endif +#endif #ifdef ENABLE_KUNLUN_API #include "kunlun/swiglu_kunlun.h" #endif #ifdef ENABLE_METAX_API +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "metax/swiglu_metax.h" #endif +#endif #ifdef ENABLE_CAMBRICON_API #include "bang/swiglu_bang.h" #endif @@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#else CREATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_NVIDIA, ninetoothed); +#else GET(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_METAX, ninetoothed); +#else GET(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API GET(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); +#else CALCULATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CALCULATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else DELETE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API DELETE(INFINI_DEVICE_QY, nvidia); #endif @@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_METAX, ninetoothed); +#else DELETE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API DELETE(INFINI_DEVICE_CAMBRICON, bang); #endif diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index 87222b299..80dcb3eb1 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -342,7 +342,10 @@ def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target for i, inp in enumerate(inputs): if isinstance(inp, torch.Tensor): # Clone only if this input will be used for comparison - if comparison_target == i: + if comparison_target == i or ( + isinstance(comparison_target, (list, tuple)) + and i in comparison_target + ): cloned_inp = clone_torch_tensor(inp) infini_tensor = infinicore_tensor_from_torch(cloned_inp) cloned_tensors.append(cloned_inp) @@ -508,7 +511,9 @@ def run_test(self, device, test_case, config): # Handle multiple outputs comparison # Determine what to compare based on comparison_target - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place multiple outputs) torch_comparison = torch_result infini_comparison = infini_result @@ -573,7 +578,9 @@ def run_test(self, device, test_case, config): # ========================================================================== else: # Determine comparison targets for single output - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place) torch_comparison = torch_result infini_comparison = infini_result diff --git a/test/infinicore/ops/flash_attention.py b/test/infinicore/ops/flash_attention.py new file mode 100644 index 000000000..2d4b09599 --- /dev/null +++ b/test/infinicore/ops/flash_attention.py @@ -0,0 +1,115 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, +) + +# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal) +# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim) + +_TEST_CASES_DATA = [ + ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False), + ((1, 2, 128, 16), (1, 2, 256, 16), (1, 2, 256, 16), None, 0.0, False), + ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, True), + ((1, 8, 256, 16), (1, 8, 512, 16), (1, 8, 512, 16), None, 0.0, True), + ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False), + ((8, 28, 256, 128), (8, 28, 512, 128), (8, 28, 512, 128), None, 0.0, True), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, +} +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + import random + + cases = [] + for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + q_spec = TensorSpec.from_tensor(q_shape, None, dtype) + k_spec = TensorSpec.from_tensor(k_shape, None, dtype) + v_spec = TensorSpec.from_tensor(v_shape, None, dtype) + + len_shape = (q_shape[0],) + total_len = random.randint(1, k_shape[2]) + total_kv_len_spec = TensorSpec.from_tensor( + len_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=total_len, + high=total_len + 1, + ) + + kwargs = { + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + } + # remove None keys + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + cases.append( + TestCase( + inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len], + kwargs=kwargs, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="Flash Attention", + ) + ) + + return cases + + +def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + k_slice = k[:, :, :cheat, :] + v_slice = v[:, :, :cheat, :] + return torch.nn.functional.scaled_dot_product_attention( + q, k_slice, v_slice, **kwargs + ) + + +def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs) + + +class OpTest(BaseOperatorTest): + """ScaledDotProductAttention operator test with simplified implementation""" + + def __init__(self): + super().__init__("ScaledDotProductAttention") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_flash_attn(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infini_flash_attn(*args, **kwargs) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/kv_caching.py b/test/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..4ca857586 --- /dev/null +++ b/test/infinicore/ops/kv_caching.py @@ -0,0 +1,134 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, + is_broadcast, +) + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (shape (bs, nkvh, seq_len, hd), strides) +_TEST_CASES_DATA = [ + ((1, 1, 8, 1), None), + ((1, 8, 32, 32), None), + ((8, 8, 64, 32), None), + ((1, 32, 8, 64), (32768, 1024, 64, 1)), + ((4, 8, 32, 16), (65536, 8192, 256, 16)), + ((8, 16, 64, 128), (8388608, 524288, 8192, 1)), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 0, "rtol": 0}, + infinicore.bfloat16: {"atol": 0, "rtol": 0}, + infinicore.float32: {"atol": 0, "rtol": 0}, +} + +# Data types to test +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + test_cases = [] + + for data in _TEST_CASES_DATA: + import random + + cache_shape = data[0] + kv_shape = ( + cache_shape[0], + cache_shape[1], + random.randint(1, cache_shape[2]), + cache_shape[3], + ) + past_shape = (cache_shape[0],) + + strides = data[1] + + past_length = random.randint(0, cache_shape[2] - kv_shape[2]) + + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0}) + + cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype) + kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype) + + past_kv_lengths_spec = TensorSpec.from_tensor( + past_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=past_length, + high=past_length + 1, + ) + + test_cases.append( + TestCase( + inputs=[ + cache_spec, + cache_spec, + kv_spec, + kv_spec, + past_kv_lengths_spec, + ], + kwargs={}, + output_spec=None, + comparison_target=[0, 1], + tolerance=tolerance, + description=f"KV Caching", + ) + ) + + return test_cases + + +def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + batch_size, num_kv_heads, _, head_dim = k_cache.shape + seq_len = k.shape[2] + + for b in range(batch_size): + past_len = past_kv_lengths[b].item() + for h in range(num_kv_heads): + k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :] + v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :] + + return k_cache, v_cache + + +def infinicore_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths) + return k_cache, v_cache + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("KV Caching") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_kv_caching(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore_kv_caching(*args, **kwargs) + + +def main(): + test_runner = GenericTestRunner(OpTest) + test_runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/scaled_dot_product_attention.py b/test/infinicore/ops/scaled_dot_product_attention.py index 218420d72..644fb6f99 100644 --- a/test/infinicore/ops/scaled_dot_product_attention.py +++ b/test/infinicore/ops/scaled_dot_product_attention.py @@ -11,17 +11,16 @@ # q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim) _TEST_CASES_DATA = [ - ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False), - ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False), - ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True), - ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False), - ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True), - ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False), + ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False), + ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False), + ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False), + ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True), + ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False), ] _TOLERANCE_MAP = { infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, - infinicore.float32: {"atol": 1e-4, "rtol": 1e-4}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, } _TENSOR_DTYPES = [infinicore.float16, infinicore.float32] @@ -68,9 +67,8 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs) def main(): diff --git a/xmake.lua b/xmake.lua index d5a4ba7f7..6070711a6 100644 --- a/xmake.lua +++ b/xmake.lua @@ -19,7 +19,7 @@ end if is_plat("windows") then set_runtimes("MD") add_ldflags("/utf-8", {force = true}) - add_cxflags("/utf-8", {force = true}) + add_cxxflags("/utf-8", {force = true}) end -- CPU @@ -218,14 +218,15 @@ target("infini-utils") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp", {force = true}) end end @@ -270,6 +271,7 @@ target("infinirt") set_languages("cxx17") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")) add_files("src/infinirt/*.cc") diff --git a/xmake/ascend.lua b/xmake/ascend.lua index 6a28979b4..e51626d1d 100644 --- a/xmake/ascend.lua +++ b/xmake/ascend.lua @@ -44,6 +44,7 @@ target("infiniop-ascend") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -62,6 +63,7 @@ target("infinirt-ascend") -- Add files add_files("$(projectdir)/src/infinirt/ascend/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-ascend") @@ -76,5 +78,6 @@ target("infiniccl-ascend") add_links("libhccl.so") add_files("../src/infiniccl/ascend/*.cc") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/bang.lua b/xmake/bang.lua index d2195acd5..ffa85ef6d 100644 --- a/xmake/bang.lua +++ b/xmake/bang.lua @@ -41,6 +41,7 @@ target("infiniop-cambricon") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -59,6 +60,7 @@ target("infinirt-cambricon") -- Add include dirs add_files("../src/infinirt/bang/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-cambricon") @@ -89,6 +91,7 @@ target("infiniccl-cambricon") add_files("../src/infiniccl/cambricon/*.cc") add_cxflags("-fPIC") + add_cxxflags("-fPIC") add_ldflags("-fPIC") else print("[Warning] CNCL is currently only supported on Linux") diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 22dc8f8e7..e192fbbbd 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -6,14 +6,15 @@ target("infiniop-cpu") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end end @@ -32,6 +33,7 @@ target("infinirt-cpu") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") diff --git a/xmake/hygon.lua b/xmake/hygon.lua index ed4b91f0e..05d3e8356 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -60,6 +60,7 @@ target("infiniop-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -76,7 +77,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) end target_end() @@ -105,6 +106,7 @@ target("infinirt-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -138,6 +140,7 @@ target("infiniccl-hygon") add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 35ccf2154..b4ba792fa 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -46,6 +46,7 @@ target("infiniop-iluvatar") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") @@ -54,7 +55,7 @@ target("infiniop-iluvatar") add_files("../src/infiniop/ops/dequantize_awq/iluvatar/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) end target_end() @@ -73,6 +74,7 @@ target("infinirt-iluvatar") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infinirt/cuda/*.cu") @@ -94,6 +96,7 @@ target("infiniccl-iluvatar") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/kunlun.lua b/xmake/kunlun.lua index 185082b3c..84ba14082 100644 --- a/xmake/kunlun.lua +++ b/xmake/kunlun.lua @@ -75,6 +75,7 @@ target("infiniop-kunlun") on_install(function (target) end) add_cxflags("-lstdc++ -fPIC -Wno-error=unused-function") + add_cxxflags("-lstdc++ -fPIC -Wno-error=unused-function") set_warnings("all", "error") set_languages("cxx17") @@ -102,6 +103,7 @@ target("infinirt-kunlun") -- Add include dirs add_files("$(projectdir)/src/infinirt/kunlun/*.cc") add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-kunlun") @@ -117,5 +119,6 @@ target("infiniccl-kunlun") add_links("bkcl") add_files("$(projectdir)/src/infiniccl/kunlun/*.cc") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua index 5561b45db..91672abe0 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -48,11 +48,12 @@ target("infiniop-metax") set_languages("cxx17") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) + add_cxxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) add_files("../src/infiniop/devices/metax/*.cc", "../src/infiniop/ops/*/metax/*.cc") add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"}) if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-include stdlib.h", "-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-include stdlib.h", "-Wno-return-type"}}) end target_end() @@ -63,6 +64,7 @@ target("infinirt-metax") add_deps("infini-utils") set_warnings("all", "error") add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") add_files("../src/infinirt/metax/*.cc") target_end() @@ -73,6 +75,7 @@ target("infiniccl-metax") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then if has_config("use-mc") then diff --git a/xmake/moore.lua b/xmake/moore.lua index 25eddf522..fdcad9564 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -42,6 +42,7 @@ target("infiniop-moore") set_languages("cxx17") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") + add_cxxflags("-lstdc++", "-fPIC", "-Wno-comment") add_files("../src/infiniop/devices/moore/*.cc") add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"}) @@ -56,6 +57,7 @@ target("infinirt-moore") add_deps("infini-utils") set_warnings("all", "error") add_cxflags("-lstdc++", "-fPIC") + add_cxxflags("-lstdc++", "-fPIC") add_files("../src/infinirt/moore/*.cc") target_end() @@ -66,6 +68,7 @@ target("infiniccl-moore") set_warnings("all", "error") if not is_plat("windows") then add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then add_links("libmccl.so") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 75086b8a1..5752dfefe 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -48,6 +48,7 @@ target("infiniop-nvidia") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") @@ -93,6 +94,7 @@ target("infinirt-nvidia") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -112,6 +114,7 @@ target("infiniccl-nvidia") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/qy.lua b/xmake/qy.lua index ecef359a8..1defe8763 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -88,6 +88,7 @@ target("infiniop-qy") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") if CUDNN_ROOT ~= nil then @@ -117,6 +118,7 @@ target("infinirt-qy") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -133,6 +135,7 @@ target("infiniccl-qy") add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/test.lua b/xmake/test.lua index 002083e1d..56dca6e5f 100644 --- a/xmake/test.lua +++ b/xmake/test.lua @@ -24,7 +24,7 @@ target("infiniop-test") add_links("infiniop", "infinirt") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end