diff --git a/CMakeLists.txt b/CMakeLists.txt index a8170dfdc..697968efe 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,6 +3,7 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON) option(USE_NPU "Enable NPU support" OFF) option(USE_MLU "Enable MLU support" OFF) +option(USE_ILU "Enable ILU support" OFF) option(USE_CUDA "Enable CUDA support" OFF) add_compile_definitions(YLT_ENABLE_IBV) add_definitions(-DYLT_ENABLE_IBV) @@ -105,7 +106,7 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS ON) -if(USE_NPU) +if(USE_NPU OR USE_ILU) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) elseif(USE_MLU OR USE_CUDA) @@ -208,6 +209,19 @@ if(USE_CUDA) message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}") endif() +if(USE_ILU) + set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules;${CMAKE_MODULE_PATH}") + set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + set(CMAKE_CUDA_ARCHITECTURES "ivcore11") + set(WARNINGS_AS_ERRORS OFF) + if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + add_definitions( + -Wno-c++11-narrowing + -Wno-thread-safety-analysis + ) + endif() +endif() + # configure vcpkg # have to set CMAKE_TOOLCHAIN_FILE before first project call. # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE) @@ -425,6 +439,23 @@ if(USE_CUDA) ) endif() +if(USE_ILU) + add_definitions(-DUSE_ILU) + set(CMAKE_VERBOSE_MAKEFILE ON) + include_directories( + $ENV{PYTHON_INCLUDE_PATH} + $ENV{PYTORCH_INSTALL_PATH}/include + $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include + $ENV{IXFORMER_INSTALL_PATH}/csrc/include/ixformer + ) + + link_directories( + $ENV{PYTHON_LIB_PATH} + $ENV{PYTORCH_INSTALL_PATH}/lib + $ENV{IXFORMER_INSTALL_PATH} + ) +endif() + # check if USE_CXX11_ABI is set correctly # if (DEFINED USE_CXX11_ABI) # parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS") diff --git a/setup.py b/setup.py index 63776ebb2..aacc0a15d 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,12 @@ def get_device_type(): if torch.cuda.is_available(): return "cuda" + try: + import ixformer + return "ilu" + except ImportError: + pass + try: import torch_mlu if torch.mlu.is_available(): @@ -143,6 +149,14 @@ def get_torch_mlu_root_path(): except ImportError: return None +def get_ixformer_root_path(): + try: + import ixformer + import os + return os.path.dirname(os.path.abspath(ixformer.__file__)) + except ImportError: + return None + def get_nccl_root_path(): try: from nvidia import nccl @@ -253,7 +267,14 @@ def set_cuda_envs(): os.environ["LIBTORCH_ROOT"] = get_torch_root_path() os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda" - + +def set_ilu_envs(): + os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path() + os.environ["PYTHON_LIB_PATH"] = get_torch_root_path() + os.environ["LIBTORCH_ROOT"] = get_torch_root_path() + os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() + os.environ["IXFORMER_INSTALL_PATH"] = get_ixformer_root_path() + class CMakeExtension(Extension): def __init__(self, name: str, path: str, sourcedir: str = "") -> None: super().__init__(name, sources=[]) @@ -337,7 +358,7 @@ def build_extension(self, ext: CMakeExtension): f"-DDEVICE_ARCH={self.arch.upper()}", f"-DINSTALL_XLLM_KERNELS={'ON' if self.install_xllm_kernels else 'OFF'}", ] - + if self.device == "a2" or self.device == "a3": cmake_args += ["-DUSE_NPU=ON"] # set npu environment variables @@ -352,6 +373,9 @@ def build_extension(self, ext: CMakeExtension): 
f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"] # set cuda environment variables set_cuda_envs() + elif self.device == "ilu": + cmake_args += ["-DUSE_ILU=ON"] + set_ilu_envs() else: raise ValueError("Please set --device to a2 or a3 or mlu or cuda.") @@ -375,6 +399,7 @@ def build_extension(self, ext: CMakeExtension): build_args = ["--config", build_type] max_jobs = os.getenv("MAX_JOBS", str(os.cpu_count())) + # max_jobs="2" build_args += ["-j" + max_jobs] env = os.environ.copy() @@ -604,9 +629,9 @@ def parse_arguments(): parser.add_argument( '--device', type=str.lower, - choices=['auto', 'a2', 'a3', 'mlu', 'cuda'], + choices=['auto', 'a2', 'a3', 'mlu', 'cuda', 'ilu'], default='auto', - help='Device type: a2, a3, mlu, or cuda (case-insensitive)' + help='Device type: a2, a3, mlu, ilu or cuda (case-insensitive)' ) parser.add_argument( diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7fb8fa937..2202cb8d2 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -19,4 +19,32 @@ target_include_directories(mooncake_store PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/Mooncake/mooncake-transfer-engine/include ) +if(USE_ILU) + if(TARGET cpprest) + set_target_properties(cpprest PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF + ) + endif() + if(TARGET transfer_engine) + target_compile_options(transfer_engine PRIVATE -std=c++20) + set_target_properties(transfer_engine PROPERTIES + CXX_STANDARD 20 + CXX_STANDARD_REQUIRED ON + ) + message(STATUS "Set C++20 for transfer_engine target") + endif() + if(TARGET SMHasherSupport) + set_target_properties(SMHasherSupport PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED ON + CXX_EXTENSIONS OFF + ) + message(STATUS "SMHasherSupport target found and configured") + else() + message(WARNING "SMHasherSupport target not found after adding smhasher") + endif() +endif() + target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator) diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp old mode 100755 new mode 100644 index 89646f3df..bd2e2ffc8 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -207,7 +207,7 @@ void BatchInputBuilder::process_sequences_multithreaded() { state_.q_seq_lens.insert(state_.q_seq_lens.end(), state.q_seq_lens.begin(), state.q_seq_lens.end()); -#elif defined(USE_MLU) || defined(USE_CUDA) +#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU) int32_t seq_len_offset = state_.seq_lens.back(); // skip the first element which is 0 for (size_t i = 1; i < state.seq_lens.size(); ++i) { @@ -293,7 +293,7 @@ void BatchInputBuilder::process_single_sequence( #if defined(USE_NPU) state.seq_lens.push_back(seq_len + offset); state.q_seq_lens.push_back(q_seq_len); -#elif defined(USE_MLU) || defined(USE_CUDA) +#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU) state.seq_lens.push_back(state.seq_lens.back() + seq_len + offset); state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len); #endif @@ -527,7 +527,7 @@ void BatchInputBuilder::padding_decode_batch_size( #if defined(USE_NPU) state_.seq_lens.push_back(num_decoding_tokens); state_.q_seq_lens.push_back(num_decoding_tokens); -#elif defined(USE_MLU) || defined(USE_CUDA) +#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU) state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens); state_.q_seq_lens.push_back(state_.q_seq_lens.back() + 
                                 num_decoding_tokens);
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index f4c8ad27f..43ab5dfee 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -85,7 +85,7 @@ class BatchInputBuilder {
 #if defined(USE_NPU)
   std::vector<int32_t> seq_lens;
   std::vector<int32_t> q_seq_lens;
-#elif defined(USE_MLU) || defined(USE_CUDA)
+#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU)
   std::vector<int32_t> seq_lens = {0};  // cu_seq_lens
   std::vector<int32_t> q_seq_lens = {0};  // q_cu_seq_len
 #endif
diff --git a/xllm/core/framework/parallel_state/CMakeLists.txt b/xllm/core/framework/parallel_state/CMakeLists.txt
index f40e73d4b..ad68cd06b 100644
--- a/xllm/core/framework/parallel_state/CMakeLists.txt
+++ b/xllm/core/framework/parallel_state/CMakeLists.txt
@@ -12,6 +12,7 @@ cc_library(
     $<$<BOOL:${USE_NPU}>:npu_process_group.h>
     $<$<BOOL:${USE_MLU}>:mlu_process_group.h>
     $<$<BOOL:${USE_CUDA}>:cuda_process_group.h>
+    $<$<BOOL:${USE_ILU}>:ilu_process_group.h>
     collective_communicator.h
   SRCS
     mapping_npu.cpp
diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp
index 2f08fb389..60a867431 100644
--- a/xllm/core/framework/parallel_state/collective_communicator.cpp
+++ b/xllm/core/framework/parallel_state/collective_communicator.cpp
@@ -25,6 +25,8 @@ limitations under the License.
 #include "mlu_process_group.h"
 #elif defined(USE_CUDA)
 #include "cuda_process_group.h"
+#elif defined(USE_ILU)
+#include "ilu_process_group.h"
 #endif
 #include "common/global_flags.h"
 #include "parallel_args.h"
diff --git a/xllm/core/framework/parallel_state/ilu_process_group.h b/xllm/core/framework/parallel_state/ilu_process_group.h
new file mode 100644
index 000000000..bc71da369
--- /dev/null
+++ b/xllm/core/framework/parallel_state/ilu_process_group.h
@@ -0,0 +1,55 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+
+#include "process_group.h"
+
+namespace xllm {
+
+class ProcessGroupImpl : public ProcessGroup {
+ public:
+  ProcessGroupImpl(int32_t global_rank,
+                   int32_t world_size,
+                   int32_t rank_size,
+                   int32_t port,
+                   bool trans,
+                   const std::string& host,
+                   const std::string& group_name,
+                   const torch::Device& device)
+      : ProcessGroup(device) {
+    c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> pg_options =
+        c10d::ProcessGroupNCCL::Options::create();
+#if TORCH_VERSION_MAJOR > 2 || \
+    (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
+    pg_options->group_name = group_name;
+#endif
+    int32_t rank = global_rank;
+    if (world_size != rank_size) {
+      auto [local_rank, group_ranks] =
+          get_group_rank(world_size, global_rank, rank_size, trans);
+      pg_options->global_ranks_in_group = group_ranks;
+      rank = local_rank;
+    }
+
+    auto store = create_tcp_store(host, port, rank);
+    pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
+        store, rank, rank_size, pg_options);
+  }
+};
+
+}  // namespace xllm
diff --git a/xllm/core/framework/parallel_state/process_group.cpp b/xllm/core/framework/parallel_state/process_group.cpp
index 1b8789305..57d7216bf 100644
--- a/xllm/core/framework/parallel_state/process_group.cpp
+++ b/xllm/core/framework/parallel_state/process_group.cpp
@@ -21,6 +21,8 @@ limitations under the License.
 #include "mlu_process_group.h"
 #elif defined(USE_CUDA)
 #include "cuda_process_group.h"
+#elif defined(USE_ILU)
+#include "ilu_process_group.h"
 #endif
 
 namespace {
diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt
index 3bba0e16b..4214bb522 100644
--- a/xllm/core/kernels/CMakeLists.txt
+++ b/xllm/core/kernels/CMakeLists.txt
@@ -12,6 +12,10 @@ if(USE_CUDA)
   add_subdirectory(cuda)
 endif()
 
+if(USE_ILU)
+  add_subdirectory(ilu)
+endif()
+
 cc_library(
   NAME
     kernels
@@ -25,4 +29,5 @@ cc_library(
     $<$<BOOL:${USE_NPU}>:npu_kernels>
     $<$<BOOL:${USE_MLU}>:mlu_kernels>
     $<$<BOOL:${USE_CUDA}>:cuda_kernels>
+    $<$<BOOL:${USE_ILU}>:ilu_kernels>
 )
\ No newline at end of file
diff --git a/xllm/core/kernels/ilu/CMakeLists.txt b/xllm/core/kernels/ilu/CMakeLists.txt
new file mode 100644
index 000000000..fa26c8865
--- /dev/null
+++ b/xllm/core/kernels/ilu/CMakeLists.txt
@@ -0,0 +1,28 @@
+include(cc_library)
+set(CMAKE_CUDA_ARCHITECTURES ivcore11)
+file(GLOB_RECURSE ILU_HEADER_FILES
+  "${CMAKE_CURRENT_LIST_DIR}/*.h"
+)
+
+file(GLOB_RECURSE ILU_SOURCE_FILES
+  "${CMAKE_CURRENT_LIST_DIR}/*.cpp"
+  "${CMAKE_CURRENT_LIST_DIR}/*.cu"
+)
+
+find_package(Python3 REQUIRED COMPONENTS Interpreter Development)
+
+cc_library(
+  NAME
+    ilu_kernels
+  HDRS
+    ${ILU_HEADER_FILES}
+  SRCS
+    ${ILU_SOURCE_FILES}
+  DEPS
+    torch
+    :util
+    ixformer_kernels
+    ixformer
+    ${Python3_LIBRARIES}
+    cuinfer
+)
diff --git a/xllm/core/kernels/ilu/activation.cpp b/xllm/core/kernels/ilu/activation.cpp
new file mode 100644
index 000000000..ae2a16ba5
--- /dev/null
+++ b/xllm/core/kernels/ilu/activation.cpp
@@ -0,0 +1,32 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ilu_ops_api.h"
+
+using namespace ixformer;
+
+namespace xllm::kernel::ilu {
+
+void act_and_mul(torch::Tensor out,
+                 torch::Tensor input,
+                 const std::string& act_mode) {
+  if (act_mode == "silu") {
+    infer::silu_and_mul(input, out);
+  } else {
+    LOG(FATAL) << "Unsupported act mode: " << act_mode
+               << ", only support silu, gelu, gelu_tanh";
+  }
+}
+}  // namespace xllm::kernel::ilu
diff --git a/xllm/core/kernels/ilu/attention.cpp b/xllm/core/kernels/ilu/attention.cpp
new file mode 100644
index 000000000..90e85ee26
--- /dev/null
+++ b/xllm/core/kernels/ilu/attention.cpp
@@ -0,0 +1,165 @@
+
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ilu_ops_api.h"
+#include "ixinfer.h"
+#include "utils.h"
+
+using namespace ixformer;
+
+namespace xllm::kernel::ilu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         std::optional<torch::Tensor>& value,
+                         torch::Tensor& key_cache,
+                         std::optional<torch::Tensor>& value_cache,
+                         torch::Tensor& slot_mapping) {
+  auto value_ = value.value_or(torch::Tensor());
+  auto value_cache_ = value_cache.value_or(torch::Tensor());
+
+  int64_t key_token_stride = key.stride(0);
+  int64_t value_token_stride = 0;
+  if (value_.defined()) {
+    value_token_stride = value_.stride(0);
+  }
+  slot_mapping = slot_mapping.to(at::kLong);
+  // translate kvcache shape from [n_blocks, block_size, n_heads, head_dim] to
+  // (num_blocks, num_heads, block_size, head_size)
+  key_cache = key_cache.permute({0, 2, 1, 3}).contiguous();
+  if (value_cache_.defined()) {
+    value_cache_ = value_cache_.permute({0, 2, 1, 3}).contiguous();
+  }
+  infer::vllm_reshape_and_cache(key,
+                                value_,
+                                key_cache,
+                                value_cache_,
+                                slot_mapping,
+                                key_token_stride,
+                                value_token_stride);
+}
+
+void batch_prefill(torch::Tensor& query,
+                   torch::Tensor& key,
+                   torch::Tensor& value,
+                   torch::Tensor& output,
+                   std::optional<torch::Tensor>& output_lse,
+                   const std::optional<torch::Tensor>& q_cu_seq_lens,
+                   const std::optional<torch::Tensor>& kv_cu_seq_lens,
+                   const std::optional<torch::Tensor>& alibi_slope,
+                   const std::optional<torch::Tensor>& attn_bias,
+                   const std::optional<torch::Tensor>& q_quant_scale,
+                   const std::optional<torch::Tensor>& k_quant_scale,
+                   const std::optional<torch::Tensor>& v_quant_scale,
+                   int64_t max_query_len,
+                   int64_t max_seq_len,
+                   float scale,
+                   bool is_causal,
+                   int64_t window_size_left,
+                   int64_t window_size_right,
+                   const std::string& compute_dtype,
+                   bool return_lse) {
+  double softcap = 0.0;
+  bool sqrt_alibi = false;
+  auto q_cu_seq_lens_ = q_cu_seq_lens.value_or(torch::Tensor());
+  auto kv_cu_seq_lens_ = kv_cu_seq_lens.value_or(torch::Tensor());
+  auto q_quant_scale_ = q_quant_scale.value_or(torch::Tensor());
+  auto k_quant_scale_ = k_quant_scale.value_or(torch::Tensor());
+  auto v_quant_scale_ = v_quant_scale.value_or(torch::Tensor());
+
+  infer::ixinfer_flash_attn_unpad(query,
+                                  key,
+                                  value,
+                                  output,
+                                  q_cu_seq_lens_,
+                                  kv_cu_seq_lens_,
+                                  max_query_len,
+                                  max_seq_len,
+                                  is_causal,
+                                  window_size_left,
+                                  window_size_right,
+                                  static_cast<double>(scale),
+                                  softcap,
+                                  sqrt_alibi,
+                                  alibi_slope,
+                                  c10::nullopt,
+                                  output_lse);
+}
+
+void batch_decode(torch::Tensor& query,
+                  torch::Tensor& k_cache,
+                  torch::Tensor& output,
+                  torch::Tensor& block_table,
+                  torch::Tensor& seq_lens,
+                  const std::optional<torch::Tensor>& v_cache,
+                  std::optional<torch::Tensor>& output_lse,
+                  const std::optional<torch::Tensor>& q_quant_scale,
+                  const std::optional<torch::Tensor>& k_cache_quant_scale,
+                  const std::optional<torch::Tensor>& v_cache_quant_scale,
+                  const std::optional<torch::Tensor>& out_quant_scale,
+                  const std::optional<torch::Tensor>& alibi_slope,
+                  const std::optional<torch::Tensor>& mask,
+                  const std::string& compute_dtype,
+                  int64_t max_seq_len,
+                  int64_t window_size_left,
+                  int64_t window_size_right,
+                  float scale,
+                  bool return_lse,
+                  bool is_causal,
+                  int64_t kv_cache_quant_bit_size) {
+  if (query.dim() == 4) {
+    query =
+        query
+            .view({query.size(0) * query.size(1), query.size(2), query.size(3)})
+            .contiguous();
+  }
+  if (output.dim() == 4) {
+    output = output
+                 .view({output.size(0) * output.size(1),
+                        output.size(2),
+                        output.size(3)})
+                 .contiguous();
+  }
+  auto v_cache_ = v_cache.value_or(torch::Tensor());
+  k_cache = k_cache.permute({0, 2, 1, 3}).contiguous();
+  v_cache_ = v_cache_.permute({0, 2, 1, 3}).contiguous();
+  int64_t num_kv_heads = k_cache.size(1);
+  int64_t page_block_size = k_cache.size(2);
+  double softcap = 0.0;
+  bool enable_cuda_graph = false;
+  bool use_sqrt_alibi = false;
+  // check_tensor_contiguous(k_cache, query.dtype());
+
+  infer::vllm_paged_attention(output,
+                              query,
+                              k_cache,
+                              v_cache_,
+                              static_cast<int64_t>(num_kv_heads),
+                              scale,
+                              block_table,
+                              seq_lens,
+                              page_block_size,
+                              max_seq_len,
+                              alibi_slope,
+                              is_causal,
+                              window_size_left,
+                              window_size_right,
+                              softcap,
+                              enable_cuda_graph,
+                              use_sqrt_alibi,
+                              c10::nullopt);
+}
+
+}  // namespace xllm::kernel::ilu
\ No newline at end of file
diff --git a/xllm/core/kernels/ilu/ilu_ops_api.h b/xllm/core/kernels/ilu/ilu_ops_api.h
new file mode 100644
index 000000000..236ce267c
--- /dev/null
+++ b/xllm/core/kernels/ilu/ilu_ops_api.h
@@ -0,0 +1,121 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "ATen/Tensor.h" +#include "ATen/cuda/CUDAEvent.h" +#include "c10/core/Device.h" +#include "c10/core/DeviceGuard.h" +#include "c10/core/GradMode.h" +#include "c10/core/InferenceMode.h" +#include "c10/core/MemoryFormat.h" +#include "c10/core/ScalarType.h" +#include "c10/core/TensorOptions.h" +#include "c10/cuda/CUDAFunctions.h" +#include "c10/cuda/CUDAGuard.h" +#include "c10/cuda/CUDAStream.h" +#include "ixformer.h" +#include "kernels/kernels.h" + +// #include "utils.h" +using namespace ixformer; + +namespace xllm::kernel::ilu { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& cos_sin_cache, + torch::Tensor& positions, + bool interleave); + +// act_mode only support silu, gelu, gelu_tanh +void act_and_mul(torch::Tensor out, + torch::Tensor input, + const std::string& act_mode); + +void reshape_paged_cache( + torch::Tensor& key, // (num_tokens, num_heads, head_size) + std::optional& value, // (num_tokens, num_heads, head_size) + torch::Tensor& key_cache, // (num_blocks, num_heads, block_size, head_size) + std::optional& + value_cache, // (num_blocks, num_heads, block_size, head_size) + torch::Tensor& slot_mapping); //(num_tokens) + +void batch_prefill(torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& output, + std::optional& output_lse, + const std::optional& q_cu_seq_lens, + const std::optional& kv_cu_seq_lens, + const std::optional& alibi_slope, + const std::optional& attn_bias, + const std::optional& q_quant_scale, + const std::optional& k_quant_scale, + const std::optional& v_quant_scale, + int64_t max_query_len, + int64_t max_seq_len, + float scale, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + const std::string& compute_dtype, + bool return_lse); + +void batch_decode(torch::Tensor& query, + torch::Tensor& k_cache, + torch::Tensor& output, + torch::Tensor& block_table, + torch::Tensor& seq_lens, + const std::optional& v_cache, + std::optional& output_lse, + const std::optional& q_quant_scale, + const std::optional& k_cache_quant_scale, + const std::optional& v_cache_quant_scale, + const std::optional& out_quant_scale, + const std::optional& alibi_slope, + const std::optional& mask, + const std::string& compute_dtype, + int64_t max_seq_len, + int64_t window_size_left, + int64_t window_size_right, + float scale, + bool return_lse, + bool is_causal, + int64_t kv_cache_quant_bit_size); + +void residual_layer_norm(torch::Tensor& input, + torch::Tensor& output, + std::optional& residual, + torch::Tensor& weight, + std::optional& beta, + std::optional& bias, + std::optional& residual_out, + double eps); + +torch::Tensor matmul(torch::Tensor a, + torch::Tensor b, + std::optional bias); + +} // namespace xllm::kernel::ilu diff --git a/xllm/core/kernels/ilu/ixformer.h b/xllm/core/kernels/ilu/ixformer.h new file mode 100644 index 000000000..1ce315377 --- /dev/null +++ b/xllm/core/kernels/ilu/ixformer.h @@ -0,0 +1,86 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "ATen/Tensor.h" +#include "utils.h" + +namespace ixformer::infer { +at::Tensor ixinfer_flash_attn_unpad( + at::Tensor& query, + at::Tensor& key, + at::Tensor& value, + at::Tensor& out, + at::Tensor& cu_seq_q, + at::Tensor& cu_seq_k, + int64_t max_seq_q, + int64_t max_seq_k, + bool is_causal, + int64_t window_left, + int64_t window_right, + double scale, + double softcap, + bool sqrt_alibi, + const c10::optional& alibi_slopes, + const c10::optional& sinks, + c10::optional& lse); + +void silu_and_mul(at::Tensor& input, at::Tensor& output); + +at::Tensor vllm_paged_attention(at::Tensor& out, + at::Tensor& query, + at::Tensor& key_cache, + at::Tensor& value_cache, + int64_t num_kv_heads, + double scale, + at::Tensor& block_tables, + at::Tensor& context_lens, + int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + bool causal, + int window_left, + int window_right, + double softcap, + bool enable_cuda_graph, + bool use_sqrt_alibi, + const c10::optional& sinks); + +void vllm_reshape_and_cache(at::Tensor& key, + at::Tensor& value, + at::Tensor& key_cache, + at::Tensor& value_cache, + at::Tensor& slot_mapping, + int64_t key_token_stride, + int64_t value_token_stride); + +void vllm_rotary_embedding(at::Tensor& positions, + at::Tensor& query, + at::Tensor& key, + int64_t head_size, + at::Tensor& cos_sin_cache, + bool is_neox); + +void residual_layer_norm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& weight, + at::Tensor& bias, + c10::optional& fused_bias, + c10::optional& output, + c10::optional& residual_output, + double alpha, + double eps, + bool is_post); +} // namespace ixformer::infer diff --git a/xllm/core/kernels/ilu/matmul.cpp b/xllm/core/kernels/ilu/matmul.cpp new file mode 100644 index 000000000..fbb28b38b --- /dev/null +++ b/xllm/core/kernels/ilu/matmul.cpp @@ -0,0 +1,27 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ilu_ops_api.h" + +namespace xllm::kernel::ilu { + +torch::Tensor matmul(torch::Tensor a, + torch::Tensor b, + std::optional bias) { + namespace F = torch::nn::functional; + return F::linear(a, b, bias.value_or(torch::Tensor())); +} + +} // namespace xllm::kernel::ilu \ No newline at end of file diff --git a/xllm/core/kernels/ilu/norm.cpp b/xllm/core/kernels/ilu/norm.cpp new file mode 100644 index 000000000..29ecd73f7 --- /dev/null +++ b/xllm/core/kernels/ilu/norm.cpp @@ -0,0 +1,49 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "ilu_ops_api.h" +#include "utils.h" + +using namespace ixformer; + +namespace xllm::kernel::ilu { + +void residual_layer_norm(torch::Tensor& input, + torch::Tensor& output, + std::optional& residual, + torch::Tensor& weight, + std::optional& beta, + std::optional& bias, + std::optional& residual_out, + double eps) { + torch::ScalarType scalar_type = input.scalar_type(); + int hidden_size = weight.numel(); + torch::Tensor beta_ = beta.value_or(at::zeros( + {hidden_size}, + at::TensorOptions().dtype(input.scalar_type()).device(input.device()))); + auto residual_ = residual.value_or(torch::zeros_like(input)); + std::optional output_ = output; + infer::residual_layer_norm(input, + residual_, + weight, + beta_, + bias, + output_, + residual_out, + 1.0, + eps, + false); +} +} // namespace xllm::kernel::ilu \ No newline at end of file diff --git a/xllm/core/kernels/ilu/rope.cpp b/xllm/core/kernels/ilu/rope.cpp new file mode 100644 index 000000000..c735350ce --- /dev/null +++ b/xllm/core/kernels/ilu/rope.cpp @@ -0,0 +1,31 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "ilu_ops_api.h" +#include "utils.h" + +namespace xllm::kernel::ilu { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& cos_sin_cache, + torch::Tensor& positions, + bool interleave) { + const int64_t head_size = cos_sin_cache.size(-1) / 2; + infer::vllm_rotary_embedding( + positions, query, key, head_size, cos_sin_cache, !interleave); +} + +} // namespace xllm::kernel::ilu diff --git a/xllm/core/kernels/ilu/utils.h b/xllm/core/kernels/ilu/utils.h new file mode 100644 index 000000000..e8af0c3c9 --- /dev/null +++ b/xllm/core/kernels/ilu/utils.h @@ -0,0 +1,63 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#pragma once +namespace xllm::kernel::ilu { +#undef check_tensor_contiguous +#define check_tensor_contiguous(x, type) \ + TORCH_CHECK(x.scalar_type() == type); \ + TORCH_CHECK(x.is_cuda()); \ + TORCH_CHECK(x.is_contiguous()); + +#undef check_tensor_half_bf_float +#define check_tensor_half_bf_float(x) \ + TORCH_CHECK(x.scalar_type() == at::ScalarType::Half || \ + x.scalar_type() == at::ScalarType::Float || \ + x.scalar_type() == at::ScalarType::BFloat16); \ + TORCH_CHECK(x.is_cuda()); + +// from torchCheckMsgImpl +inline const char* ixformer_check_msg_impl(const char* msg) { return msg; } +// // If there is just 1 user-provided C-string argument, use it. + +#define IXFORMER_CHECK_MSG(cond, type, ...) \ + (ixformer_check_msg_impl( \ + "Expected " #cond \ + " to be true, but got false. " \ + "(Could this error message be improved? If so, " \ + "please report an enhancement request to ixformer.)", \ + ##__VA_ARGS__)) + +#define IXFORMER_CHECK(cond, ...) \ + { \ + if (!(cond)) { \ + std::cerr << __FILE__ << " (" << __LINE__ << ")" \ + << "-" << __FUNCTION__ << " : " \ + << IXFORMER_CHECK_MSG(cond, "", ##__VA_ARGS__) << std::endl; \ + throw std::runtime_error("IXFORMER_CHECK ERROR"); \ + } \ + } + +#undef CUINFER_CHECK +#define CUINFER_CHECK(func) \ + do { \ + cuinferStatus_t status = (func); \ + if (status != CUINFER_STATUS_SUCCESS) { \ + std::cerr << "Error in file " << __FILE__ << " on line " << __LINE__ \ + << ": " << cuinferGetErrorString(status) << std::endl; \ + throw std::runtime_error("CUINFER_CHECK ERROR"); \ + } \ + } while (0) + +} // namespace xllm::kernel::ilu \ No newline at end of file diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 07ab88f62..ac19282ac 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -19,6 +19,8 @@ limitations under the License. 
#include "mlu/mlu_ops_api.h" #elif defined(USE_CUDA) #include "cuda/cuda_ops_api.h" +#elif defined(USE_ILU) +#include "ilu/ilu_ops_api.h" #endif #include @@ -48,6 +50,13 @@ void apply_rotary(RotaryParams& params) { auto cos_sin = torch::cat({cos, sin}, -1); cuda::rotary_embedding(pos_ids, params.q, params.k, cos_sin, is_neox); +#elif defined(USE_ILU) + torch::Tensor long_position_ids = params.position_ids.value().to(at::kLong); + ilu::apply_rope_pos_ids_cos_sin_cache(params.q, + params.k, + params.cos_sin, + long_position_ids, + params.interleaved); #else LOG(FATAL) << "apply_rotary not implemented"; #endif @@ -65,6 +74,8 @@ void active(ActivationParams& params) { params.expert_size); #elif defined(USE_CUDA) cuda::act_and_mul(params.output, params.input, params.act_mode); +#elif defined(USE_ILU) + ilu::act_and_mul(params.output, params.input, params.act_mode); #else LOG(FATAL) << "active not implemented"; #endif @@ -84,6 +95,13 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) { params.value.value_or(torch::Tensor()), params.k_cache, params.v_cache.value_or(torch::Tensor())); +#elif defined(USE_ILU) + // auto v_cache = params.v_cache.value_or(torch::Tensor()); + ilu::reshape_paged_cache(params.key, + params.value, + params.k_cache, + params.v_cache, + params.slot_mapping); #else LOG(FATAL) << "reshape_paged_cache not implemented"; #endif @@ -127,6 +145,27 @@ void batch_prefill(AttentionParams& params) { params.output, params.output_lse, params.enable_cuda_graph); +#elif defined(USE_ILU) + ilu::batch_prefill(params.query, + params.key, + params.value, + params.output, + params.output_lse, + params.q_cu_seq_lens, + params.kv_cu_seq_lens, + params.alibi_slope, + params.attn_bias, + params.q_quant_scale, + params.k_quant_scale, + params.v_quant_scale, + params.max_query_len, + params.max_seq_len, + params.scale, + params.is_causal, + params.window_size_left, + params.window_size_right, + params.compute_dtype, + params.return_lse); #else LOG(FATAL) << "batch_prefill not implemented"; #endif @@ -171,6 +210,28 @@ void batch_decode(AttentionParams& params) { params.output, params.output_lse, params.enable_cuda_graph); +#elif defined(USE_ILU) + ilu::batch_decode(params.query, + params.k_cache, + params.output, + params.block_table.value(), + params.kv_seq_lens, + params.v_cache, + params.output_lse, + params.q_quant_scale, + params.k_cache_quant_scale, + params.v_cache_quant_scale, + params.out_quant_scale, + params.alibi_slope, + params.mask, + params.compute_dtype, + params.max_seq_len, + params.window_size_left, + params.window_size_right, + params.scale, + params.return_lse, + params.is_causal, + params.kv_cache_quant_bit_size); #else LOG(FATAL) << "batch_decode not implemented"; #endif @@ -202,6 +263,15 @@ void fused_layernorm(FusedLayerNormParams& params) { } else { cuda::rms_norm(params.output, params.input, params.weight, params.eps); } +#elif defined(USE_ILU) + ilu::residual_layer_norm(params.input, + params.output, + params.residual, + params.weight, + params.beta, // weight_bias + params.bias, // residual_bias + params.residual_out, + params.eps); #else LOG(FATAL) << "fused_layernorm not implemented"; #endif @@ -213,6 +283,8 @@ torch::Tensor matmul(MatmulParams& params) { params.a, params.b, params.bias, params.c, params.alpha, params.beta); #elif defined(USE_CUDA) return cuda::matmul(params.a, params.b, params.bias); +#elif defined(USE_ILU) + return ilu::matmul(params.a, params.b, params.bias); #else LOG(FATAL) << "matmul not implemented"; #endif diff --git 
a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 49f8ceec0..685a408bf 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -6,12 +6,15 @@ if(USE_NPU) ) add_subdirectory(npu) else() + add_subdirectory(common) if(USE_MLU) add_subdirectory(mlu) + elseif(USE_ILU) + add_subdirectory(ilu) else() add_subdirectory(cuda) endif() - add_subdirectory(common) + endif() cc_library( @@ -77,6 +80,7 @@ cc_library( $<$:npu_layers> $<$:mlu_layers> $<$:cuda_layers> + $<$:ilu_layers> :parallel_state :state_dict :kv_cache diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h index 992fc9749..29ac9c88d 100644 --- a/xllm/core/layers/common/indexer.h +++ b/xllm/core/layers/common/indexer.h @@ -25,6 +25,8 @@ limitations under the License. #include "../mlu/attention.h" #elif defined(USE_CUDA) #include "../cuda/attention.h" +#elif defined(USE_ILU) +#include "../ilu/attention.h" #endif #include "framework/model/model_input_params.h" #include "framework/parallel_state/parallel_args.h" diff --git a/xllm/core/layers/common/qwen2_attention.h b/xllm/core/layers/common/qwen2_attention.h index 839072f3d..159634514 100644 --- a/xllm/core/layers/common/qwen2_attention.h +++ b/xllm/core/layers/common/qwen2_attention.h @@ -21,6 +21,8 @@ limitations under the License. #include "../mlu/attention.h" #elif defined(USE_CUDA) #include "../cuda/attention.h" +#elif defined(USE_ILU) +#include "../ilu/attention.h" #endif #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_args.h" diff --git a/xllm/core/layers/common/rotary_embedding.h b/xllm/core/layers/common/rotary_embedding.h index f7147e5be..7c2cb93fc 100644 --- a/xllm/core/layers/common/rotary_embedding.h +++ b/xllm/core/layers/common/rotary_embedding.h @@ -24,6 +24,8 @@ limitations under the License. #include "../mlu/attention.h" #elif defined(USE_CUDA) #include "../cuda/attention.h" +#elif defined(USE_ILU) +#include "../ilu/attention.h" #endif #include "framework/model/model_args.h" #include "layers/rotary_embedding.h" diff --git a/xllm/core/layers/common/word_embedding_impl.h b/xllm/core/layers/common/word_embedding_impl.h index a749c8e6a..31c0275d3 100644 --- a/xllm/core/layers/common/word_embedding_impl.h +++ b/xllm/core/layers/common/word_embedding_impl.h @@ -15,6 +15,7 @@ limitations under the License. #pragma once +#include #include #include diff --git a/xllm/core/layers/ilu/CMakeLists.txt b/xllm/core/layers/ilu/CMakeLists.txt new file mode 100755 index 000000000..822ad32d5 --- /dev/null +++ b/xllm/core/layers/ilu/CMakeLists.txt @@ -0,0 +1,12 @@ +include(cc_library) + +cc_library( + NAME + ilu_layers + HDRS + attention.h + SRCS + attention.cpp + DEPS + :common_layers +) diff --git a/xllm/core/layers/ilu/attention.cpp b/xllm/core/layers/ilu/attention.cpp new file mode 100644 index 000000000..dbeb9e1b5 --- /dev/null +++ b/xllm/core/layers/ilu/attention.cpp @@ -0,0 +1,206 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "attention.h" + +#include "kernels/ops_api.h" + +namespace xllm { +namespace layer { + +AttentionMetadata AttentionMetadata::build(const ModelInputParams& params, + bool is_prefill) { + return AttentionMetadata::build(params, "float", is_prefill); +} + +AttentionMetadata AttentionMetadata::build(const ModelInputParams& params, + const std::string& compute_dtype, + bool is_prefill) { + AttentionMetadata attn_metadata; + attn_metadata.query_start_loc = params.q_seq_lens; + attn_metadata.seq_start_loc = params.kv_seq_lens; + attn_metadata.max_query_len = params.q_max_seq_len; + attn_metadata.max_seq_len = params.kv_max_seq_len; + attn_metadata.slot_mapping = params.new_cache_slots; + attn_metadata.compute_dtype = compute_dtype; + attn_metadata.q_cu_seq_lens = params.q_seq_lens; + attn_metadata.kv_cu_seq_lens = params.kv_seq_lens; // cumulative kv seqlens + + bool is_start_loc_match = (params.q_seq_lens_vec == params.kv_seq_lens_vec); + attn_metadata.is_chunked_prefill = is_prefill && !is_start_loc_match; + attn_metadata.is_prefill = is_prefill && !attn_metadata.is_chunked_prefill; + if (!attn_metadata.is_prefill || FLAGS_enable_mla) { + attn_metadata.block_table = params.block_tables; + attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens); // kv seqlens + } + + attn_metadata.is_dummy = (params.q_max_seq_len == 0); + + return attn_metadata; +} + +AttentionImpl::AttentionImpl(int64_t num_heads, + int64_t head_size, + float scale, + int64_t num_kv_heads, + int64_t sliding_window) + : num_heads_(num_heads), + head_size_(head_size), + num_kv_heads_(num_kv_heads), + v_head_dim_(head_size), + sliding_window_(sliding_window), + scale_(scale), + use_fused_mla_qkv_(false), + enable_lighting_indexer_(false), + enable_mla_(false) { + if (sliding_window_ > -1) { + sliding_window_ = sliding_window_ - 1; + } +} + +AttentionImpl::AttentionImpl(int64_t num_heads, + int64_t head_size, + int64_t num_kv_heads, + int64_t v_head_dim, + int64_t sliding_window, + float scale, + bool use_fused_mla_qkv, + bool enable_lighting_indexer) + : num_heads_(num_heads), + head_size_(head_size), + num_kv_heads_(num_kv_heads), + v_head_dim_(v_head_dim), + sliding_window_(sliding_window), + use_fused_mla_qkv_(use_fused_mla_qkv), + scale_(scale), + enable_lighting_indexer_(enable_lighting_indexer), + enable_mla_(FLAGS_enable_mla) { + if (sliding_window_ > -1) { + sliding_window_ = sliding_window_ - 1; + } +} + +std::tuple> AttentionImpl::forward( + const AttentionMetadata& attn_metadata, + torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& value, + KVCache& kv_cache) { + std::optional output_lse = std::nullopt; + torch::Tensor output; + if (enable_mla_) { + output = torch::empty({query.size(0), num_heads_ * v_head_dim_}, + query.options()); + } else { + output = torch::empty_like(query); + } + if (attn_metadata.is_dummy) { + return std::make_tuple(output, output_lse); + } + + bool only_prefill = + attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; + int64_t num_kv_heads = (enable_mla_ && !only_prefill) ? 
1 : num_kv_heads_; + torch::Tensor k_cache = kv_cache.get_k_cache(); + std::optional v_cache; + std::optional v; + if (!enable_mla_) { + v = value.view({-1, num_kv_heads, head_size_}); + v_cache = kv_cache.get_v_cache(); + } + + bool skip_process_cache = enable_mla_ && (only_prefill || use_fused_mla_qkv_); + if (!skip_process_cache) { + xllm::kernel::ReshapePagedCacheParams reshape_paged_cache_params; + reshape_paged_cache_params.key = key.view({-1, num_kv_heads, head_size_}); + reshape_paged_cache_params.value = v; + reshape_paged_cache_params.k_cache = k_cache; + reshape_paged_cache_params.v_cache = v_cache; + reshape_paged_cache_params.slot_mapping = attn_metadata.slot_mapping; + xllm::kernel::reshape_paged_cache(reshape_paged_cache_params); + } + + if (enable_lighting_indexer_ || !only_prefill) { + decoder_forward(query, output, k_cache, v_cache, attn_metadata); + } else { + prefill_forward(query, key, value, output, k_cache, v_cache, attn_metadata); + } + + int64_t head_size = enable_mla_ ? v_head_dim_ : head_size_; + output = output.view({-1, num_heads_ * head_size}); + return {output, output_lse}; +} + +void AttentionImpl::prefill_forward(torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& output, + const torch::Tensor& k_cache, + const std::optional& v_cache, + const AttentionMetadata& attn_metadata) { + int64_t head_size_v = enable_mla_ ? v_head_dim_ : head_size_; + xllm::kernel::AttentionParams attention_params; + attention_params.query = query.view({-1, num_heads_, head_size_}); + attention_params.output = output.view({-1, num_heads_, head_size_v}); + attention_params.max_seq_len = attn_metadata.max_seq_len; + attention_params.window_size_left = sliding_window_; + attention_params.scale = scale_; + attention_params.compute_dtype = attn_metadata.compute_dtype; + + attention_params.query_start_loc = attn_metadata.query_start_loc; + attention_params.seq_start_loc = attn_metadata.seq_start_loc; + attention_params.max_query_len = attn_metadata.max_query_len; + + if (attn_metadata.is_prefill) { + attention_params.key = key.view({-1, num_kv_heads_, head_size_}); + attention_params.value = value.view({-1, num_kv_heads_, head_size_v}); + attention_params.block_table = std::nullopt; + } else if (attn_metadata.is_chunked_prefill) { + attention_params.key = k_cache; + attention_params.value = v_cache.value(); + attention_params.block_table = attn_metadata.block_table; + } + + attention_params.kv_cu_seq_lens = attn_metadata.kv_cu_seq_lens; + attention_params.q_cu_seq_lens = attn_metadata.q_cu_seq_lens; + xllm::kernel::batch_prefill(attention_params); +} + +void AttentionImpl::decoder_forward(torch::Tensor& query, + torch::Tensor& output, + const torch::Tensor& k_cache, + const std::optional& v_cache, + const AttentionMetadata& attn_metadata) { + int64_t head_size_v = enable_mla_ ? 
v_head_dim_ : head_size_; + xllm::kernel::AttentionParams attention_params; + attention_params.query = query.view({-1, 1, num_heads_, head_size_}); + attention_params.output = output.view({-1, 1, num_heads_, head_size_v}); + attention_params.output_lse = std::nullopt; + attention_params.max_seq_len = attn_metadata.max_seq_len; + attention_params.window_size_left = sliding_window_; + attention_params.scale = scale_; + attention_params.compute_dtype = attn_metadata.compute_dtype; + attention_params.k_cache = k_cache; + attention_params.v_cache = v_cache; + + // for mlu + attention_params.block_table = attn_metadata.block_table; + attention_params.kv_seq_lens = attn_metadata.kv_seq_lens; + + xllm::kernel::batch_decode(attention_params); +} + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/ilu/attention.h b/xllm/core/layers/ilu/attention.h new file mode 100644 index 000000000..d14d08c2c --- /dev/null +++ b/xllm/core/layers/ilu/attention.h @@ -0,0 +1,110 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include + +#include "framework/kv_cache/kv_cache.h" +#include "framework/model/model_input_params.h" + +namespace xllm { +namespace layer { + +struct AttentionMetadata { + public: + static AttentionMetadata build(const ModelInputParams& params, + bool is_prefill); + + static AttentionMetadata build(const ModelInputParams& params, + const std::string& compute_dtype, + bool is_prefill); + + torch::Tensor query_start_loc; + torch::Tensor seq_start_loc; + torch::Tensor kv_seq_lens; + torch::Tensor block_table; + torch::Tensor slot_mapping; + int64_t max_query_len; + int64_t max_seq_len; + std::string compute_dtype; + bool is_prefill; + bool is_chunked_prefill; + bool is_dummy; + + // for mrope + torch::Tensor mrope_cos; + torch::Tensor mrope_sin; + + torch::Tensor q_cu_seq_lens; + torch::Tensor kv_cu_seq_lens; +}; + +class AttentionImpl : public torch::nn::Module { + public: + AttentionImpl() = default; + + AttentionImpl(int64_t num_heads, + int64_t head_size, + float scale, + int64_t num_kv_heads, + int64_t sliding_window); + AttentionImpl(int64_t num_heads, + int64_t head_size, + int64_t num_kv_heads, + int64_t v_head_dim, + int64_t sliding_window, + float scale, + bool use_fused_mla_qkv, + bool enable_lighting_indexer); + + std::tuple> forward( + const AttentionMetadata& attn_metadata, + torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& value, + KVCache& kv_cache); + + void prefill_forward(torch::Tensor& query, + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& output, + const torch::Tensor& k_cache, + const std::optional& v_cache, + const AttentionMetadata& attn_metadata); + + void decoder_forward(torch::Tensor& query, + torch::Tensor& output, + const torch::Tensor& k_cache, + const std::optional& v_cache, + const AttentionMetadata& attn_metadata); + + private: + int64_t num_heads_; + 
int64_t head_size_; + float scale_; + int64_t num_kv_heads_; + int64_t v_head_dim_; + bool use_fused_mla_qkv_; + bool enable_mla_; + bool enable_lighting_indexer_; + int64_t sliding_window_; +}; +TORCH_MODULE(Attention); + +} // namespace layer +} // namespace xllm diff --git a/xllm/core/platform/CMakeLists.txt b/xllm/core/platform/CMakeLists.txt index 3549c6161..4c0e6fb8a 100644 --- a/xllm/core/platform/CMakeLists.txt +++ b/xllm/core/platform/CMakeLists.txt @@ -18,8 +18,8 @@ cc_library( $<$:torch_mlu> $<$:cnrt> $<$:cndrv> - $<$:cuda> - $<$:cudart> + $<$,$>:cuda> + $<$,$>:cudart> ) if(USE_NPU) diff --git a/xllm/core/platform/device.cpp b/xllm/core/platform/device.cpp index 41e432557..26ae85de5 100644 --- a/xllm/core/platform/device.cpp +++ b/xllm/core/platform/device.cpp @@ -21,7 +21,7 @@ limitations under the License. #include #include #include -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) #include #include #endif @@ -37,7 +37,7 @@ void Device::set_device() const { c10_npu::set_device(index()); #elif defined(USE_MLU) torch_mlu::setDevice(index()); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) c10::cuda::set_device(index()); #endif } @@ -74,7 +74,7 @@ int Device::device_count() { return c10_npu::device_count(); #elif defined(USE_MLU) return torch_mlu::device_count(); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) return c10::cuda::device_count(); #endif } @@ -84,7 +84,7 @@ std::string Device::type_str() { return "npu"; #elif defined(USE_MLU) return "mlu"; -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) return "cuda"; #endif } @@ -92,7 +92,7 @@ std::string Device::type_str() { torch::DeviceType Device::type_torch() { #if defined(USE_NPU) || defined(USE_MLU) return torch::kPrivateUse1; -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) return torch::kCUDA; #endif } @@ -106,7 +106,7 @@ Device::DeviceMem Device::get_device_mem() const { aclrtGetMemInfo(ACL_HBM_MEM, &free_memory, &total_memory); #elif defined(USE_MLU) cnrtMemGetInfo(&free_memory, &total_memory); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) cudaMemGetInfo(&free_memory, &total_memory); #endif device_mem.total_memory = static_cast(total_memory); @@ -123,7 +123,7 @@ int Device::synchronize_default_stream() { return aclrtSynchronizeStream(c10_npu::getCurrentNPUStream(index()).stream()); #elif defined(USE_MLU) torch_mlu::getCurrentMLUStream(index()).synchronize(); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) c10::cuda::getCurrentCUDAStream().synchronize(); #endif return 0; diff --git a/xllm/core/platform/stream.cpp b/xllm/core/platform/stream.cpp index 90dce066e..e13a2e2bc 100644 --- a/xllm/core/platform/stream.cpp +++ b/xllm/core/platform/stream.cpp @@ -23,7 +23,7 @@ Stream::Stream(const int32_t timeout) #elif defined(USE_MLU) Stream::Stream(const int32_t timeout) : stream_(torch_mlu::getStreamFromPool()), timeout_(timeout) {} -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) Stream::Stream(const int32_t timeout) : stream_(c10::cuda::getStreamFromPool()), timeout_(timeout) {} #endif @@ -34,7 +34,7 @@ int Stream::synchronize() const { #elif defined(USE_MLU) stream_.unwrap().synchronize(); return 0; -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) stream_.synchronize(); return 0; #else @@ -44,7 +44,7 @@ int Stream::synchronize() const { } c10::StreamGuard Stream::set_stream_guard() const { -#if defined(USE_CUDA) +#if defined(USE_CUDA) 
|| defined(USE_ILU) return c10::StreamGuard(stream_); #else return c10::StreamGuard(stream_.unwrap()); diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h index ecd8bd5e8..128b8517f 100644 --- a/xllm/core/platform/stream.h +++ b/xllm/core/platform/stream.h @@ -30,7 +30,7 @@ limitations under the License. #include #elif defined(USE_MLU) #include -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) #include #endif @@ -59,7 +59,7 @@ class Stream { c10_npu::NPUStream stream_; #elif defined(USE_MLU) torch_mlu::MLUStream stream_; -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) c10::cuda::CUDAStream stream_; #endif const int32_t timeout_; diff --git a/xllm/core/platform/vmm_api.cpp b/xllm/core/platform/vmm_api.cpp index 098034345..2a93c1a63 100644 --- a/xllm/core/platform/vmm_api.cpp +++ b/xllm/core/platform/vmm_api.cpp @@ -67,7 +67,7 @@ void create_phy_mem_handle(PhyMemHandle& phy_mem_handle, int32_t device_id) { accessDesc.location.id = device_id; accessDesc.accessFlags = CN_MEM_ACCESS_FLAGS_PROT_READWRITE; ret = cnMemSetAccess(phy_mem_handle, granularity_size, &accessDesc, 1); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) CUmemAllocationProp prop = {}; prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; @@ -97,7 +97,7 @@ void create_vir_ptr(VirPtr& vir_ptr, size_t aligned_size) { ret = aclrtReserveMemAddress(&vir_ptr, aligned_size, 0, nullptr, 0); #elif defined(USE_MLU) ret = cnMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) ret = cuMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0); #endif CHECK_EQ(ret, 0) << "Failed to create virtual memory handle"; @@ -109,7 +109,7 @@ void release_phy_mem_handle(PhyMemHandle& phy_mem_handle) { ret = aclrtFreePhysical(phy_mem_handle); #elif defined(USE_MLU) ret = cnMemRelease(phy_mem_handle); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) ret = cuMemRelease(phy_mem_handle); #endif CHECK_EQ(ret, 0) << "Failed to release physical memory handle"; @@ -121,7 +121,7 @@ void release_vir_ptr(VirPtr& vir_ptr, size_t aligned_size) { ret = aclrtReleaseMemAddress(vir_ptr); #elif defined(USE_MLU) ret = cnMemAddressFree(vir_ptr, aligned_size); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) ret = cuMemAddressFree(vir_ptr, aligned_size); #endif CHECK_EQ(ret, 0) << "Failed to release virtual memory handle"; @@ -135,7 +135,7 @@ void map(VirPtr& vir_ptr, PhyMemHandle& phy_mem_handle) { #elif defined(USE_MLU) ret = cnMemMap(vir_ptr, FLAGS_phy_page_granularity_size, 0, phy_mem_handle, 0); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) ret = cuMemMap(vir_ptr, FLAGS_phy_page_granularity_size, 0, phy_mem_handle, 0); #endif @@ -148,7 +148,7 @@ void unmap(VirPtr& vir_ptr, size_t aligned_size) { ret = aclrtUnmapMem(vir_ptr); #elif defined(USE_MLU) ret = cnMemUnmap(vir_ptr, aligned_size); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) ret = cuMemUnmap(vir_ptr, aligned_size); #endif CHECK_EQ(ret, 0) << "Failed to unmap virtual memory from physical memory"; diff --git a/xllm/core/platform/vmm_api.h b/xllm/core/platform/vmm_api.h index 8a6b1444c..321fa2a9e 100644 --- a/xllm/core/platform/vmm_api.h +++ b/xllm/core/platform/vmm_api.h @@ -19,7 +19,7 @@ limitations under the License. 
#include "acl/acl.h" #elif defined(USE_MLU) #include -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) #include #endif @@ -31,7 +31,7 @@ using PhyMemHandle = aclrtDrvMemHandle; #elif defined(USE_MLU) using VirPtr = CNaddr; using PhyMemHandle = CNmemGenericAllocationHandle; -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) using VirPtr = CUdeviceptr; using PhyMemHandle = CUmemGenericAllocationHandle; #endif diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 047c2a828..3f52fc136 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -32,7 +32,7 @@ limitations under the License. #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" #include "framework/state_dict/state_dict.h" -#if defined(USE_CUDA) +#if defined(USE_CUDA) || defined(USE_ILU) #include "layers/cuda/flashinfer_workspace.h" #endif #include "models/model_registry.h" diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 0bc187fbe..aa1ccef04 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -25,7 +25,7 @@ limitations under the License. #include "kernels/npu/xllm_ops/replace_token.h" #elif defined(USE_MLU) #include -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) #include #endif @@ -351,7 +351,7 @@ std::tuple WorkerImpl::estimate_kv_cache_capacity() { device_id, &torch_cache, &torch_largest_block); #elif defined(USE_MLU) torch_mlu::MLUCachingAllocator::emptyCache(); -#elif defined(USE_CUDA) +#elif defined(USE_CUDA) || defined(USE_ILU) c10::cuda::CUDACachingAllocator::emptyCache(); #endif const auto available_memory = device_.free_memory(); diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index 6c380cde0..07aa7549b 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -44,6 +44,9 @@ limitations under the License. #if defined(USE_CUDA) #include "core/layers/cuda/attention.h" #endif +#if defined(USE_ILU) +#include "core/layers/ilu/attention.h" +#endif #if defined(USE_MLU) #include "core/layers/mlu/attention.h" #endif