Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/infinicore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import contextlib

with contextlib.suppress(ImportError):
from ._preload import preload
preload()

# import torch
import infinicore.context as context
import infinicore.nn as nn

Expand Down
121 changes: 121 additions & 0 deletions python/infinicore/_preload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import ctypes
import os
from typing import Iterable, List


def _candidate_prefixes(path: str) -> List[str]:
"""
Return HPCC install prefixes to search for libs.
Prefer HPCC_PATH; if absent and explicitly opted-in, fall back to /opt/hpcc.
"""
prefixes: List[str] = []
if path:
prefixes.append(path)

seen = set()
unique: List[str] = []
for p in prefixes:
if p and p not in seen:
seen.add(p)
unique.append(p)
return unique


def _try_load(paths: Iterable[str], name: str) -> bool:
"""Try to load a shared library from given paths or system search path."""
for path in paths:
full = os.path.join(path, "lib", name)
if os.path.exists(full):
try:
ctypes.CDLL(full, mode=ctypes.RTLD_GLOBAL)
return True
except OSError:
# Try next candidate
continue
# Last resort: rely on loader search path
try:
ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL)
return True
except OSError:
return False


def preload_hpcc() -> None:
"""
Best-effort preload of key HPCC runtime libs with RTLD_GLOBAL.

This mirrors the behavior of torch's HPCC build that loads libtorch_global_deps.so,
but avoids introducing a hard torch dependency. All failures are swallowed.
"""
hpcc_path = os.getenv("HPCC_PATH")
if not hpcc_path:
return

prefixes = _candidate_prefixes(hpcc_path)
libs = [
"libhcruntime.so",
"libhcToolsExt.so",
"libruntime_cu.so",
"libhccompiler.so",
]

for lib in libs:
_try_load(prefixes, lib)


def _should_preload_device(device_type: str) -> bool:
"""
Check if preload is needed for a specific device type.
"""
device_env_map = {
"METAX": ["HPCC_PATH", "INFINICORE_PRELOAD_HPCC"], # HPCC/METAX
# Add other device types here as needed:
# "ASCEND": ["ASCEND_PATH"],
# "CAMBRICON": ["NEUWARE_HOME"],
}

env_vars = device_env_map.get(device_type, [])
for env_var in env_vars:
if os.getenv(env_var):
return True
return False


def preload_device(device_type: str) -> None:
"""
Preload runtime libraries for a specific device type if needed.

Args:
device_type: Device type name (e.g., "METAX", "ASCEND", etc.)
"""
if device_type == "METAX":
preload_hpcc()
# Add other device preload functions here as needed:
# elif device_type == "ASCEND":
# preload_ascend()
# etc.


def preload() -> None:
"""
Universal preload function that loops through device types and preloads when required.

This function detects available device types and preloads their runtime libraries
if the environment indicates they are needed.
"""
# Device types that may require preload
device_types = [
"METAX", # HPCC/METAX
# Add other device types here as they are implemented:
# "ASCEND",
# "CAMBRICON",
# etc.
]

for device_type in device_types:
if _should_preload_device(device_type):
try:
preload_device(device_type)
except Exception:
# Swallow all errors - preload is best-effort
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_METAX_H__
#define __PAGED_ATTENTION_METAX_H__

#include "../paged_attention.h"

DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_METAX_H__
149 changes: 149 additions & 0 deletions src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#ifdef ENABLE_METAX_MC_API
#include <mccub/block/block_reduce.cuh>
#else
#include <hccub/block/block_reduce.cuh>
#endif

#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_metax.h"

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_METAX_KERNEL pagedAttention(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
}

namespace op::paged_attention::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);

return INFINI_STATUS_SUCCESS;
}

template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
size_t num_heads, size_t num_seqs,
size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
hcStream_t stream) {
dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
dim3 block(NUM_THREADS);
size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);

if (dtype == INFINI_DTYPE_F16) {
pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)out,
(const half *)q, (const half *)k_cache, (const half *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedAttention<cuda_bfloat16, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)out, (const cuda_bfloat16 *)q, (const cuda_bfloat16 *)k_cache, (const cuda_bfloat16 *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;

#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);

#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}

int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_1024)
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_512)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}

#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE

return INFINI_STATUS_SUCCESS;
}

} // namespace op::paged_attention::metax
30 changes: 15 additions & 15 deletions src/infiniop/ops/paged_attention/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_attention_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif

__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
Expand All @@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
#define __PAGED_ATTENTION_PREFILL_METAX_H__

#include "../paged_attention_prefill.h"

DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
Loading
Loading