diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index c6b01d5aa..a0898460d 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -1,5 +1,10 @@
 import contextlib
+with contextlib.suppress(ImportError):
+    from ._preload import preload
+    preload()
+
+# import torch
 import infinicore.context as context
 import infinicore.nn as nn
diff --git a/python/infinicore/_preload.py b/python/infinicore/_preload.py
new file mode 100644
index 000000000..fc5ff6560
--- /dev/null
+++ b/python/infinicore/_preload.py
@@ -0,0 +1,121 @@
+import ctypes
+import os
+from typing import Iterable, List
+
+
+def _candidate_prefixes(path: str) -> List[str]:
+    """
+    Return HPCC install prefixes to search for libs.
+    Currently only the prefix passed in (taken from HPCC_PATH) is searched;
+    a /opt/hpcc fallback is not implemented.
+    """
+    prefixes: List[str] = []
+    if path:
+        prefixes.append(path)
+
+    seen = set()
+    unique: List[str] = []
+    for p in prefixes:
+        if p and p not in seen:
+            seen.add(p)
+            unique.append(p)
+    return unique
+
+
+def _try_load(paths: Iterable[str], name: str) -> bool:
+    """Try to load a shared library from the given paths or the system search path."""
+    for path in paths:
+        full = os.path.join(path, "lib", name)
+        if os.path.exists(full):
+            try:
+                ctypes.CDLL(full, mode=ctypes.RTLD_GLOBAL)
+                return True
+            except OSError:
+                # Try the next candidate.
+                continue
+    # Last resort: rely on the loader search path.
+    try:
+        ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL)
+        return True
+    except OSError:
+        return False
+
+
+def preload_hpcc() -> None:
+    """
+    Best-effort preload of key HPCC runtime libs with RTLD_GLOBAL.
+
+    This mirrors the behavior of torch's HPCC build, which loads
+    libtorch_global_deps.so, but avoids introducing a hard torch dependency.
+    All failures are swallowed.
+    """
+    hpcc_path = os.getenv("HPCC_PATH")
+    if not hpcc_path:
+        return
+
+    prefixes = _candidate_prefixes(hpcc_path)
+    libs = [
+        "libhcruntime.so",
+        "libhcToolsExt.so",
+        "libruntime_cu.so",
+        "libhccompiler.so",
+    ]
+
+    for lib in libs:
+        _try_load(prefixes, lib)
+
+
+def _should_preload_device(device_type: str) -> bool:
+    """Check whether preload is needed for a specific device type."""
+    device_env_map = {
+        "METAX": ["HPCC_PATH", "INFINICORE_PRELOAD_HPCC"],  # HPCC/METAX
+        # Add other device types here as needed:
+        # "ASCEND": ["ASCEND_PATH"],
+        # "CAMBRICON": ["NEUWARE_HOME"],
+    }
+
+    env_vars = device_env_map.get(device_type, [])
+    for env_var in env_vars:
+        if os.getenv(env_var):
+            return True
+    return False
+
+
+def preload_device(device_type: str) -> None:
+    """
+    Preload runtime libraries for a specific device type if needed.
+
+    Args:
+        device_type: Device type name (e.g., "METAX", "ASCEND").
+    """
+    if device_type == "METAX":
+        preload_hpcc()
+    # Add other device preload functions here as needed:
+    # elif device_type == "ASCEND":
+    #     preload_ascend()
+
+
+def preload() -> None:
+    """
+    Universal preload entry point.
+
+    Loops over the known device types and preloads their runtime libraries
+    when the environment indicates they are needed.
+    """
+    # Device types that may require preload.
+    device_types = [
+        "METAX",  # HPCC/METAX
+        # Add other device types here as they are implemented:
+        # "ASCEND",
+        # "CAMBRICON",
+    ]
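Note: at the loader level, `ctypes.CDLL(path, mode=ctypes.RTLD_GLOBAL)` is a `dlopen` with `RTLD_GLOBAL`, which is what makes the runtime's symbols visible to shared objects loaded afterwards. A minimal C++ sketch of the same step (the library name is taken from the list above; error handling mirrors the best-effort behavior of the Python code):

    #include <dlfcn.h>
    #include <cstdio>

    int main() {
        // RTLD_GLOBAL publishes the library's symbols to subsequently
        // loaded shared objects - the whole point of preloading.
        void *handle = dlopen("libhcruntime.so", RTLD_NOW | RTLD_GLOBAL);
        if (!handle) {
            // Best-effort, like the Python code: report and carry on.
            std::fprintf(stderr, "preload failed: %s\n", dlerror());
        }
        return 0;
    }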
+
+    for device_type in device_types:
+        if _should_preload_device(device_type):
+            try:
+                preload_device(device_type)
+            except Exception:
+                # Swallow all errors - preload is best-effort.
+                pass
diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h
new file mode 100644
index 000000000..82a5b3e59
--- /dev/null
+++ b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.h
@@ -0,0 +1,8 @@
+#ifndef __PAGED_ATTENTION_METAX_H__
+#define __PAGED_ATTENTION_METAX_H__
+
+#include "../paged_attention.h"
+
+DESCRIPTOR(metax)
+
+#endif // __PAGED_ATTENTION_METAX_H__
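Note: `DESCRIPTOR(metax)` is defined in `../paged_attention.h` and stamps out the backend `Descriptor` class inside `op::paged_attention::metax`. A self-contained toy sketch of the PIMPL shape that the `.maca` file below fills in (stub names only; the real class carries tensor descriptors, an `infiniStatus_t` API, and the device handle):

    struct PagedAttentionInfoStub {}; // stand-in for the validated op info

    namespace op::paged_attention::metax_sketch {

    class Descriptor {
        struct Opaque; // backend-private state, defined out-of-line
        Opaque *_opaque;
        PagedAttentionInfoStub _info;

        Descriptor(Opaque *opaque, PagedAttentionInfoStub info)
            : _opaque(opaque), _info(info) {}

    public:
        static Descriptor *create(); // validates descriptors, builds Opaque
        ~Descriptor();
    };

    struct Descriptor::Opaque {}; // really holds the device handle internals

    Descriptor *Descriptor::create() {
        return new Descriptor(new Opaque{}, PagedAttentionInfoStub{});
    }

    Descriptor::~Descriptor() { delete _opaque; }

    } // namespace op::paged_attention::metax_sketch

    int main() {
        delete op::paged_attention::metax_sketch::Descriptor::create();
        return 0;
    }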
diff --git a/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
new file mode 100644
index 000000000..6b06efc37
--- /dev/null
+++ b/src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
@@ -0,0 +1,149 @@
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_kernel_common.h"
+#ifdef ENABLE_METAX_MC_API
+#include
+#else
+#include
+#endif
+
+#include "../../../reduce/cuda/reduce.cuh"
+#include "../cuda/kernel.cuh"
+#include "paged_attention_metax.h"
+
+template <unsigned int HEAD_SIZE, unsigned int NUM_THREADS, typename Tdata>
+INFINIOP_METAX_KERNEL pagedAttention(
+    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
+    const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
+    const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
+    const size_t block_size,
+    const ptrdiff_t q_stride,
+    const ptrdiff_t kv_block_stride,
+    const ptrdiff_t kv_head_stride,
+    const ptrdiff_t o_stride) {
+    op::paged_attention::cuda::pagedAttentionKernel<HEAD_SIZE, NUM_THREADS>(
+        out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
+        max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
+}
+
+namespace op::paged_attention::metax {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_cache_desc,
+    infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t block_tables_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
+    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
+    float scale) {
+    auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
+    CHECK_RESULT(info);
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        info.take(), 0, handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template <unsigned int HEAD_SIZE, unsigned int NUM_THREADS>
+infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
+                            infiniDtype_t dtype,
+                            const void *block_tables, const void *seq_lens, const void *alibi_slopes,
+                            size_t num_heads, size_t num_seqs,
+                            size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
+                            ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
+                            hcStream_t stream) {
+    dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
+    dim3 block(NUM_THREADS);
+    size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
+
+    if (dtype == INFINI_DTYPE_F16) {
+        pagedAttention<HEAD_SIZE, NUM_THREADS>
+            <<<grid, block, shared_mem_size, stream>>>(
+                (half *)out,
+                (const half *)q, (const half *)k_cache, (const half *)v_cache,
+                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
+                scale, max_num_blocks_per_seq, block_size,
+                q_stride, kv_block_stride, kv_head_stride, o_stride);
+    } else if (dtype == INFINI_DTYPE_BF16) {
+        pagedAttention<HEAD_SIZE, NUM_THREADS>
+            <<<grid, block, shared_mem_size, stream>>>(
+                (cuda_bfloat16 *)out, (const cuda_bfloat16 *)q, (const cuda_bfloat16 *)k_cache, (const cuda_bfloat16 *)v_cache,
+                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
+                scale, max_num_blocks_per_seq, block_size,
+                q_stride, kv_block_stride, kv_head_stride, o_stride);
+    } else if (dtype == INFINI_DTYPE_F32) {
+        pagedAttention<HEAD_SIZE, NUM_THREADS>
+            <<<grid, block, shared_mem_size, stream>>>(
+                (float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
+                (const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
+                scale, max_num_blocks_per_seq, block_size,
+                q_stride, kv_block_stride, kv_head_stride, o_stride);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *out, const void *q, const void *k_cache, const void *v_cache,
+    const void *block_tables, const void *seq_lens, const void *alibi_slopes,
+    void *stream_) const {
+    hcStream_t stream = (hcStream_t)stream_;
+
+#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE)                                        \
+    return launchKernel<__H_SIZE, __B_SIZE>(                                                 \
+        out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,         \
+        _info.num_heads, _info.num_seqs,                                                     \
+        _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,     \
+        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,         \
+        stream);
+
+#define SWITCH_HEAD_SIZE(__B_SIZE)                  \
+    switch (_info.head_size) {                      \
+    case 16:                                        \
+        LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE)     \
+    case 32:                                        \
+        LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE)     \
+    case 64:                                        \
+        LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE)     \
+    case 128:                                       \
+        LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE)    \
+    case 256:                                       \
+        LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE)    \
+    default:                                        \
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;      \
+    }
+
+    int max_threads = _opaque->internal->maxThreadsPerBlock();
+    if (max_threads >= METAX_BLOCK_SIZE_1024) {
+        SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_1024)
+    } else if (max_threads >= METAX_BLOCK_SIZE_512) {
+        SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_512)
+    } else {
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+
+#undef LAUNCH_HEADSIZE_BLOCKSIZE
+#undef SWITCH_HEAD_SIZE
+}
+
+} // namespace op::paged_attention::metax
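Note: the dynamic shared-memory buffer sized in `launchKernel` holds one float per head element, one per cached token (`max_num_blocks_per_seq * block_size`), plus two extra slots (presumably the running max and sum of the softmax reduction). A worked example with made-up sizes:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Illustrative sizes only - not taken from the PR.
        const std::size_t HEAD_SIZE = 128;
        const std::size_t max_num_blocks_per_seq = 64;
        const std::size_t block_size = 16; // KV-cache page size in tokens

        const std::size_t shared_mem_size =
            (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);
        // 128 + 64 * 16 + 2 = 1154 floats -> 4616 bytes
        std::printf("%zu bytes\n", shared_mem_size);
        return 0;
    }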
diff --git a/src/infiniop/ops/paged_attention/operator.cc b/src/infiniop/ops/paged_attention/operator.cc
index 1d7d4fee3..46bea9e1e 100644
--- a/src/infiniop/ops/paged_attention/operator.cc
+++ b/src/infiniop/ops/paged_attention/operator.cc
@@ -5,9 +5,9 @@
 #ifdef ENABLE_NVIDIA_API
 #include "nvidia/paged_attention_nvidia.cuh"
 #endif
-// #ifdef ENABLE_METAX_API
-// #include "metax/paged_attention_metax.h"
-// #endif
+#ifdef ENABLE_METAX_API
+#include "metax/paged_attention_metax.h"
+#endif
 
 __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
     infiniopHandle_t handle,
@@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     CREATE(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     GET(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     CALCULATE(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     DESTROY(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        DESTROY(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
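Note: the `CREATE`/`GET`/`CALCULATE`/`DESTROY` macros are presumably defined earlier in `operator.cc` (not shown in these hunks). A self-contained sketch of the dispatch pattern they implement, with stand-in types and bodies rather than the real macros:

    #include <cstdio>

    enum Device { INFINI_DEVICE_NVIDIA, INFINI_DEVICE_METAX };

    namespace nvidia { int create() { return std::puts("nvidia create"); } }
    namespace metax  { int create() { return std::puts("metax create"); } }

    // Each macro expands to a switch case that forwards to the backend namespace.
    #define CREATE(CASE, NAMESPACE) \
        case CASE:                  \
            return NAMESPACE::create();

    int dispatch(Device device) {
        switch (device) {
            CREATE(INFINI_DEVICE_NVIDIA, nvidia)
            CREATE(INFINI_DEVICE_METAX, metax)
        default:
            return -1; // INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
        }
    }

    int main() { return dispatch(INFINI_DEVICE_METAX) >= 0 ? 0 : 1; }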
diff --git a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h
new file mode 100644
index 000000000..03b6cef3c
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.h
@@ -0,0 +1,8 @@
+#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
+#define __PAGED_ATTENTION_PREFILL_METAX_H__
+
+#include "../paged_attention_prefill.h"
+
+DESCRIPTOR(metax)
+
+#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
diff --git a/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
new file mode 100644
index 000000000..6a2f3a722
--- /dev/null
+++ b/src/infiniop/ops/paged_attention_prefill/metax/paged_attention_prefill_metax.maca
@@ -0,0 +1,129 @@
+#include <cstdint>
+#include <memory>
+#include <optional>
+
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_kernel_common.h"
+#include "../cuda/kernel.cuh"
+#include "paged_attention_prefill_metax.h"
+
+template <typename Tdata, typename Tcompute>
+infiniStatus_t launchPagedAttentionPrefill(
+    Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
+    const int64_t *block_tables,
+    const int64_t *seq_lens,
+    const int64_t *cum_seq_lens_q,
+    const float *alibi_slopes,
+    const size_t num_heads,
+    const size_t num_seqs,
+    const size_t num_kv_heads,
+    const float scale,
+    const size_t max_num_blocks_per_seq,
+    const size_t block_size,
+    const size_t total_q_tokens,
+    const size_t head_size,
+    const ptrdiff_t kv_block_stride,
+    const ptrdiff_t kv_head_stride,
+    const ptrdiff_t q_stride,
+    const ptrdiff_t q_head_stride,
+    hcStream_t stream) {
+
+    if (total_q_tokens == 0 || num_heads == 0) {
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    if (head_size == 0 || head_size > 1024) {
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    dim3 grid(static_cast<uint32_t>(total_q_tokens), static_cast<uint32_t>(num_heads));
+    dim3 block(static_cast<uint32_t>(head_size));
+
+    op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
+        <<<grid, block, 0, stream>>>(
+            out, q, k_cache, v_cache,
+            block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
+            num_heads, num_kv_heads, scale,
+            max_num_blocks_per_seq, block_size,
+            kv_block_stride, kv_head_stride,
+            q_stride, q_head_stride,
+            head_size,
+            num_seqs);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+namespace op::paged_attention_prefill::metax {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_cache_desc,
+    infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t block_tables_desc,
+    infiniopTensorDescriptor_t seq_lens_desc,
+    infiniopTensorDescriptor_t cum_seq_lens_q_desc,
+    const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
+    float scale) {
+
+    auto info = PagedAttentionPrefillInfo::create(
+        out_desc, q_desc, k_cache_desc, v_cache_desc,
+        block_tables_desc, seq_lens_desc,
+        cum_seq_lens_q_desc,
+        alibi_slopes_desc, scale);
+
+    CHECK_RESULT(info);
+
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        info.take(), 0, handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *out, const void *q, const void *k_cache, const void *v_cache,
+    const void *block_tables,
+    const void *seq_lens,
+    const void *cum_seq_lens_q,
+    const void *alibi_slopes,
+    void *stream_) const {
+
+    hcStream_t stream = (hcStream_t)stream_;
+
+#define LAUNCH_KERNEL(Tdata, Tcompute)                                                             \
+    launchPagedAttentionPrefill<Tdata, Tcompute>(                                                  \
+        (Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache,            \
+        (const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
+        (const float *)alibi_slopes,                                                               \
+        _info.num_heads, _info.num_seqs, _info.num_kv_heads,                                       \
+        _info.scale, _info.max_num_blocks_per_seq,                                                 \
+        _info.block_size, _info.total_q_tokens,                                                    \
+        _info.head_size,                                                                           \
+        _info.kv_block_stride, _info.kv_head_stride,                                               \
+        _info.q_stride, _info.q_head_stride,                                                       \
+        stream)
+
+    if (_info.dtype == INFINI_DTYPE_F16) {
+        return LAUNCH_KERNEL(half, float);
+    } else if (_info.dtype == INFINI_DTYPE_BF16) {
+        return LAUNCH_KERNEL(cuda_bfloat16, float);
+    } else if (_info.dtype == INFINI_DTYPE_F32) {
+        return LAUNCH_KERNEL(float, float);
+    }
+
+    return INFINI_STATUS_BAD_TENSOR_DTYPE;
+}
+
+} // namespace op::paged_attention_prefill::metax
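Note: the prefill launch geometry assigns one thread per element of the head dimension and one thread block per (query token, head) pair, which is why `head_size` is capped at 1024 (the usual hardware thread-block limit). A small sketch of the same derivation with made-up sizes:

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    int main() {
        // Illustrative sizes only - not taken from the PR.
        const std::size_t total_q_tokens = 6; // sum of query lengths in the batch
        const std::size_t num_heads = 8;
        const std::size_t head_size = 128;

        // Same guards as launchPagedAttentionPrefill.
        assert(total_q_tokens > 0 && num_heads > 0);
        assert(head_size > 0 && head_size <= 1024);

        // grid = (total_q_tokens, num_heads), block = head_size
        std::printf("grid = (%zu, %zu), block = %zu\n",
                    total_q_tokens, num_heads, head_size);
        return 0;
    }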
diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc
index e205acca1..af21df651 100644
--- a/src/infiniop/ops/paged_attention_prefill/operator.cc
+++ b/src/infiniop/ops/paged_attention_prefill/operator.cc
@@ -5,6 +5,9 @@
 #ifdef ENABLE_NVIDIA_API
 #include "nvidia/paged_attention_prefill_nvidia.cuh"
 #endif
+#ifdef ENABLE_METAX_API
+#include "metax/paged_attention_prefill_metax.h"
+#endif
 
 __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     infiniopHandle_t handle,
@@ -32,6 +35,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
     switch (handle->device) {
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -50,6 +56,9 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
     switch (desc->device_type) {
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -75,6 +84,9 @@ __C infiniStatus_t infiniopPagedAttentionPrefill(
     switch (desc->device_type) {
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
@@ -92,6 +104,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
     switch (desc->device_type) {
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+#ifdef ENABLE_METAX_API
+        DESTROY(INFINI_DEVICE_METAX, metax)
 #endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
diff --git a/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h
new file mode 100644
index 000000000..7ac3fda2c
--- /dev/null
+++ b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.h
@@ -0,0 +1,8 @@
+#ifndef __PAGED_CACHING_METAX_H__
+#define __PAGED_CACHING_METAX_H__
+
+#include "../paged_caching.h"
+
+DESCRIPTOR(metax)
+
+#endif // __PAGED_CACHING_METAX_H__
diff --git a/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
new file mode 100644
index 000000000..db761992f
--- /dev/null
+++ b/src/infiniop/ops/paged_caching/metax/paged_caching_metax.maca
@@ -0,0 +1,157 @@
+#include "../../../devices/metax/metax_common.h"
+#include "../../../devices/metax/metax_kernel_common.h"
+#include "../cuda/kernel.cuh"
+#include "paged_caching_metax.h"
+
+template <typename Tdata>
+INFINIOP_METAX_KERNEL pagedCaching(
+    Tdata *k_cache, Tdata *v_cache,
+    const Tdata *k, const Tdata *v,
+    const int64_t *slot_mapping,
+    const size_t head_size, const size_t block_size,
+    const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride,
+    const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride) {
+    op::paged_caching::cuda::pagedCachingKernel(
+        k_cache, v_cache, k, v, slot_mapping, head_size,
+        block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride);
+}
+
+namespace op::paged_caching::metax {
+
+// PIMPL struct definition
+struct Descriptor::Opaque {
+    std::shared_ptr<device::metax::Handle::Internal> internal;
+};
+
+// Destructor implementation
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+// Static factory method implementation
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t k_cache_desc,
+    infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
+    infiniopTensorDescriptor_t slot_mapping_desc) {
+
+    auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
+    CHECK_RESULT(info);
+
+    // Create and return the Descriptor instance.
+    *desc_ptr = new Descriptor(
+        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
+        info.take(), 0, handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+// The launchKernel function is a templated helper to encapsulate the kernel launch.
+// It sets up grid/block dimensions and calls the device-side kernel.
+template <unsigned int NUM_THREADS>
+infiniStatus_t launchKernel(const PagedCachingInfo &info,
+                            void *k_cache, void *v_cache,
+                            infiniDtype_t dtype,
+                            const void *k, const void *v,
+                            const void *slot_mapping,
+                            size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size,
+                            ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
+                            ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride,
+                            hcStream_t stream) {
+
+    // The grid is 2D: one thread block per (kv_head, token) pair.
+    dim3 grid(uint64_t(num_kv_heads), uint64_t(num_tokens), 1);
+    // Block dimension is 1D, using the number of threads fixed at compile time.
+    dim3 block(NUM_THREADS);
+
+    // This kernel does not require dynamic shared memory.
+    size_t shared_mem_size = 0;
+
+    // Launch the device-side kernel.
+    if (dtype == INFINI_DTYPE_F16) {
+        pagedCaching
+            <<<grid, block, shared_mem_size, stream>>>(
+                (half *)k_cache,
+                (half *)v_cache,
+                (const half *)k,
+                (const half *)v,
+                (const int64_t *)slot_mapping,
+                head_size,
+                block_size,
+                k_src_stride,
+                v_src_stride,
+                k_cache_block_stride,
+                v_cache_block_stride);
+    } else if (dtype == INFINI_DTYPE_BF16) {
+        pagedCaching
+            <<<grid, block, shared_mem_size, stream>>>(
+                (cuda_bfloat16 *)k_cache,
+                (cuda_bfloat16 *)v_cache,
+                (const cuda_bfloat16 *)k,
+                (const cuda_bfloat16 *)v,
+                (const int64_t *)slot_mapping,
+                head_size,
+                block_size,
+                k_src_stride,
+                v_src_stride,
+                k_cache_block_stride,
+                v_cache_block_stride);
+    } else if (dtype == INFINI_DTYPE_F32) {
+        pagedCaching
+            <<<grid, block, shared_mem_size, stream>>>(
+                (float *)k_cache,
+                (float *)v_cache,
+                (const float *)k,
+                (const float *)v,
+                (const int64_t *)slot_mapping,
+                head_size,
+                block_size,
+                k_src_stride,
+                v_src_stride,
+                k_cache_block_stride,
+                v_cache_block_stride);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+
+// Execution method implementation
+infiniStatus_t Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *k_cache, void *v_cache,
+    const void *k, const void *v,
+    const void *slot_mapping,
+    void *stream_) const {
+
+    hcStream_t stream = (hcStream_t)stream_;
+
+    // Dispatch logic based on the device's maximum threads per block.
+    // This allows selecting the largest, most efficient block size the hardware supports.
+    int max_threads = _opaque->internal->maxThreadsPerBlock();
+    if (max_threads >= METAX_BLOCK_SIZE_1024) {
+        // Dispatch (by data type, inside launchKernel) for a 1024-thread block.
+        return launchKernel<METAX_BLOCK_SIZE_1024>(
+            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
+            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
+            _info.k_src_stride, _info.v_src_stride,
+            _info.k_cache_block_stride, _info.v_cache_block_stride,
+            stream);
+    } else if (max_threads >= METAX_BLOCK_SIZE_512) {
+        return launchKernel<METAX_BLOCK_SIZE_512>(
+            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
+            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
+            _info.k_src_stride, _info.v_src_stride,
+            _info.k_cache_block_stride, _info.v_cache_block_stride,
+            stream);
+    } else {
+        // If the device supports fewer threads, return an error.
+        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
+    }
+}
+
+} // namespace op::paged_caching::metax
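Note: the thread-block size is chosen at runtime from the device's `maxThreadsPerBlock()` cap and fixed at compile time via the `launchKernel<NUM_THREADS>` template. A tiny runnable sketch of the same selection (the constants stand in for `METAX_BLOCK_SIZE_1024` / `METAX_BLOCK_SIZE_512`):

    #include <cstdio>

    constexpr int BLOCK_1024 = 1024;
    constexpr int BLOCK_512 = 512;

    // Pick the largest supported thread-block size, or fail.
    int pickBlockSize(int max_threads_per_block) {
        if (max_threads_per_block >= BLOCK_1024) return BLOCK_1024;
        if (max_threads_per_block >= BLOCK_512) return BLOCK_512;
        return -1; // device architecture not supported
    }

    int main() {
        std::printf("cap 2048 -> %d\n", pickBlockSize(2048)); // 1024
        std::printf("cap 640  -> %d\n", pickBlockSize(640));  // 512
        std::printf("cap 256  -> %d\n", pickBlockSize(256));  // -1
        return 0;
    }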
diff --git a/src/infiniop/ops/paged_caching/operator.cc b/src/infiniop/ops/paged_caching/operator.cc
index 3bfd92280..6eb746f9f 100644
--- a/src/infiniop/ops/paged_caching/operator.cc
+++ b/src/infiniop/ops/paged_caching/operator.cc
@@ -5,9 +5,9 @@
 #ifdef ENABLE_NVIDIA_API
 #include "nvidia/paged_caching_nvidia.cuh"
 #endif
-// #ifdef ENABLE_METAX_API
-// #include "metax/paged_caching_metax.h"
-// #endif
+#ifdef ENABLE_METAX_API
+#include "metax/paged_caching_metax.h"
+#endif
 
 __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
     infiniopHandle_t handle,
@@ -29,9 +29,9 @@ __C infiniStatus_t infiniopCreatePagedCachingDescriptor(
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     CREATE(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -50,9 +50,9 @@ __C infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     GET(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -75,9 +75,9 @@ __C infiniStatus_t infiniopPagedCaching(
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     CALCULATE(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
@@ -95,9 +95,9 @@ __C infiniStatus_t infiniopDestroyPagedCachingDescriptor(
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-    // #ifdef ENABLE_METAX_API
-    //     DESTROY(INFINI_DEVICE_METAX, metax)
-    // #endif
+#ifdef ENABLE_METAX_API
+        DESTROY(INFINI_DEVICE_METAX, metax)
+#endif
     default:
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
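Note: all three METAX ops share the same dtype-dispatch shape: branch once on the runtime dtype tag, then call a helper templated on the concrete element type. A generic, self-contained sketch of that pattern (stand-in types; not the real kernels or status codes):

    #include <cstdio>

    enum Dtype { DTYPE_F16, DTYPE_BF16, DTYPE_F32 };

    template <typename Tdata>
    int launch(const void *data) {
        const Tdata *typed = static_cast<const Tdata *>(data);
        (void)typed; // a real launch would pass typed pointers to the kernel
        return 0;    // INFINI_STATUS_SUCCESS
    }

    int dispatch(Dtype dtype, const void *data) {
        if (dtype == DTYPE_F32) return launch<float>(data);
        if (dtype == DTYPE_F16) return launch<unsigned short>(data);  // half stand-in
        if (dtype == DTYPE_BF16) return launch<unsigned short>(data); // bf16 stand-in
        return -1; // INFINI_STATUS_BAD_TENSOR_DTYPE
    }

    int main() {
        float x = 1.0f;
        return dispatch(DTYPE_F32, &x);
    }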