33 changes: 32 additions & 1 deletion CMakeLists.txt
@@ -3,6 +3,7 @@ set_property(GLOBAL PROPERTY USE_FOLDERS ON)

option(USE_NPU "Enable NPU support" OFF)
option(USE_MLU "Enable MLU support" OFF)
option(USE_ILU "Enable ILU support" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
add_compile_definitions(YLT_ENABLE_IBV)
add_definitions(-DYLT_ENABLE_IBV)
@@ -105,7 +106,7 @@ set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)

if(USE_NPU)
if(USE_NPU OR USE_ILU)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
elseif(USE_MLU OR USE_CUDA)
@@ -208,6 +209,19 @@ if(USE_CUDA)
message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
endif()

if(USE_ILU)
set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/cmake/Modules;${CMAKE_MODULE_PATH}")
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CUDA_ARCHITECTURES "ivcore11")
set(WARNINGS_AS_ERRORS OFF)
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_definitions(
-Wno-c++11-narrowing
-Wno-thread-safety-analysis
)
endif()
endif()

# configure vcpkg
# have to set CMAKE_TOOLCHAIN_FILE before first project call.
# if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
@@ -425,6 +439,23 @@ if(USE_CUDA)
)
endif()

if(USE_ILU)
add_definitions(-DUSE_ILU)
set(CMAKE_VERBOSE_MAKEFILE ON)
include_directories(
$ENV{PYTHON_INCLUDE_PATH}
$ENV{PYTORCH_INSTALL_PATH}/include
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
$ENV{IXFORMER_INSTALL_PATH}/csrc/include/ixformer
)

link_directories(
$ENV{PYTHON_LIB_PATH}
$ENV{PYTORCH_INSTALL_PATH}/lib
$ENV{IXFORMER_INSTALL_PATH}
)
endif()

# check if USE_CXX11_ABI is set correctly
# if (DEFINED USE_CXX11_ABI)
# parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
33 changes: 29 additions & 4 deletions setup.py
@@ -38,6 +38,12 @@ def get_device_type():
if torch.cuda.is_available():
return "cuda"

try:
import ixformer
return "ilu"
except ImportError:
pass

try:
import torch_mlu
if torch.mlu.is_available():
@@ -143,6 +149,14 @@ def get_torch_mlu_root_path():
except ImportError:
return None

def get_ixformer_root_path():
try:
import ixformer
import os
return os.path.dirname(os.path.abspath(ixformer.__file__))
except ImportError:
return None

def get_nccl_root_path():
try:
from nvidia import nccl
@@ -253,7 +267,14 @@ def set_cuda_envs():
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
os.environ["CUDA_TOOLKIT_ROOT_DIR"] = "/usr/local/cuda"


def set_ilu_envs():
os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
os.environ["IXFORMER_INSTALL_PATH"] = get_ixformer_root_path()

class CMakeExtension(Extension):
def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
@@ -337,7 +358,7 @@ def build_extension(self, ext: CMakeExtension):
f"-DDEVICE_ARCH={self.arch.upper()}",
f"-DINSTALL_XLLM_KERNELS={'ON' if self.install_xllm_kernels else 'OFF'}",
]

if self.device == "a2" or self.device == "a3":
cmake_args += ["-DUSE_NPU=ON"]
# set npu environment variables
@@ -352,6 +373,9 @@
f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
# set cuda environment variables
set_cuda_envs()
elif self.device == "ilu":
cmake_args += ["-DUSE_ILU=ON"]
set_ilu_envs()
else:
raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")

@@ -375,6 +399,7 @@ def build_extension(self, ext: CMakeExtension):

build_args = ["--config", build_type]
max_jobs = os.getenv("MAX_JOBS", str(os.cpu_count()))
# max_jobs="2"
build_args += ["-j" + max_jobs]

env = os.environ.copy()
@@ -604,9 +629,9 @@ def parse_arguments():
parser.add_argument(
'--device',
type=str.lower,
choices=['auto', 'a2', 'a3', 'mlu', 'cuda'],
choices=['auto', 'a2', 'a3', 'mlu', 'cuda', 'ilu'],
default='auto',
help='Device type: a2, a3, mlu, or cuda (case-insensitive)'
help='Device type: a2, a3, mlu, ilu, or cuda (case-insensitive)'
)

parser.add_argument(
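For reference, a minimal sketch (not part of the PR) of how the environment variables exported by `set_ilu_envs()` line up with the `$ENV{...}` lookups added to the top-level CMakeLists.txt; the helper name `check_ilu_env_wiring` is hypothetical:

```python
import os

# Hypothetical helper: verifies that every env var consumed by the
# USE_ILU branch of CMakeLists.txt has been exported before cmake runs.
def check_ilu_env_wiring():
    required = [
        "PYTHON_INCLUDE_PATH",    # include_directories($ENV{PYTHON_INCLUDE_PATH})
        "PYTORCH_INSTALL_PATH",   # .../include and .../lib
        "IXFORMER_INSTALL_PATH",  # .../csrc/include/ixformer and link dir
        "PYTHON_LIB_PATH",        # link_directories($ENV{PYTHON_LIB_PATH})
    ]
    missing = [name for name in required if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"ILU build env incomplete: {missing}")
```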
28 changes: 28 additions & 0 deletions third_party/CMakeLists.txt
@@ -19,4 +19,32 @@ target_include_directories(mooncake_store PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/Mooncake/mooncake-transfer-engine/include
)

if(USE_ILU)
if(TARGET cpprest)
set_target_properties(cpprest PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)
endif()
if(TARGET transfer_engine)
target_compile_options(transfer_engine PRIVATE -std=c++20)
set_target_properties(transfer_engine PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
)
message(STATUS "Set C++20 for transfer_engine target")
endif()
if(TARGET SMHasherSupport)
set_target_properties(SMHasherSupport PROPERTIES
CXX_STANDARD 11
CXX_STANDARD_REQUIRED ON
CXX_EXTENSIONS OFF
)
message(STATUS "SMHasherSupport target found and configured")
else()
message(WARNING "SMHasherSupport target not found after adding smhasher")
endif()
endif()

target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator)
6 changes: 3 additions & 3 deletions xllm/core/framework/batch/batch_input_builder.cpp
100755 → 100644
@@ -207,7 +207,7 @@ void BatchInputBuilder::process_sequences_multithreaded() {
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
state.q_seq_lens.begin(),
state.q_seq_lens.end());
#elif defined(USE_MLU) || defined(USE_CUDA)
#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU)
int32_t seq_len_offset = state_.seq_lens.back();
// skip the first element which is 0
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
@@ -293,7 +293,7 @@ void BatchInputBuilder::process_single_sequence(
#if defined(USE_NPU)
state.seq_lens.push_back(seq_len + offset);
state.q_seq_lens.push_back(q_seq_len);
#elif defined(USE_MLU) || defined(USE_CUDA)
#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU)
state.seq_lens.push_back(state.seq_lens.back() + seq_len + offset);
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
#endif
@@ -527,7 +527,7 @@ void BatchInputBuilder::padding_decode_batch_size(
#if defined(USE_NPU)
state_.seq_lens.push_back(num_decoding_tokens);
state_.q_seq_lens.push_back(num_decoding_tokens);
#elif defined(USE_MLU) || defined(USE_CUDA)
#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU)
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
num_decoding_tokens);
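The USE_MLU/USE_CUDA/USE_ILU branch in the multithreaded path merges per-thread cumulative arrays by offsetting each incoming array with the current tail and skipping its leading zero. A small Python sketch of that merge (illustration only, not project code):

```python
def merge_cu_seq_lens(dst, src):
    # dst and src are cumulative seq-lens arrays, each starting with 0.
    offset = dst[-1]
    dst.extend(offset + v for v in src[1:])  # skip src's leading 0
    return dst

assert merge_cu_seq_lens([0, 5, 8], [0, 2, 6]) == [0, 5, 8, 10, 14]
```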
2 changes: 1 addition & 1 deletion xllm/core/framework/batch/batch_input_builder.h
@@ -85,7 +85,7 @@ class BatchInputBuilder {
#if defined(USE_NPU)
std::vector<int32_t> seq_lens;
std::vector<int32_t> q_seq_lens;
#elif defined(USE_MLU) || defined(USE_CUDA)
#elif defined(USE_MLU) || defined(USE_CUDA) || defined(USE_ILU)
std::vector<int32_t> seq_lens = {0}; // cu_seq_lens
std::vector<int32_t> q_seq_lens = {0}; // q_cu_seq_len
#endif
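The `= {0}` initializers reflect the cumulative (`cu_seq_lens`) layout consumed by the MLU/CUDA/ILU kernels, as opposed to the per-sequence lengths kept on NPU; a sketch of the relationship, with semantics assumed from the comments above:

```python
seq_lens = [5, 3, 7]     # NPU-style: one entry per sequence
cu_seq_lens = [0]        # MLU/CUDA/ILU-style: prefix sums with a leading 0
for n in seq_lens:
    cu_seq_lens.append(cu_seq_lens[-1] + n)
# Tokens of sequence i occupy the range [cu_seq_lens[i], cu_seq_lens[i+1]).
assert cu_seq_lens == [0, 5, 8, 15]
```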
1 change: 1 addition & 0 deletions xllm/core/framework/parallel_state/CMakeLists.txt
@@ -12,6 +12,7 @@ cc_library(
$<$<BOOL:${USE_NPU}>:npu_process_group.h>
$<$<BOOL:${USE_MLU}>:mlu_process_group.h>
$<$<BOOL:${USE_CUDA}>:cuda_process_group.h>
$<$<BOOL:${USE_ILU}>:ilu_process_group.h>
collective_communicator.h
SRCS
mapping_npu.cpp
@@ -25,6 +25,8 @@ limitations under the License.
#include "mlu_process_group.h"
#elif defined(USE_CUDA)
#include "cuda_process_group.h"
#elif defined(USE_ILU)
#include "ilu_process_group.h"
#endif
#include "common/global_flags.h"
#include "parallel_args.h"
55 changes: 55 additions & 0 deletions xllm/core/framework/parallel_state/ilu_process_group.h
@@ -0,0 +1,55 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>

#include "process_group.h"

namespace xllm {

class ProcessGroupImpl : public ProcessGroup {
public:
ProcessGroupImpl(int32_t global_rank,
int32_t world_size,
int32_t rank_size,
int32_t port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device)
: ProcessGroup(device) {
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> pg_options =
c10d::ProcessGroupNCCL::Options::create();
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
pg_options->group_name = group_name;
#endif
int32_t rank = global_rank;
if (world_size != rank_size) {
auto [local_rank, group_ranks] =
get_group_rank(world_size, global_rank, rank_size, trans);
pg_options->global_ranks_in_group = group_ranks;
rank = local_rank;
}

auto store = create_tcp_store(host, port, rank);
pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
store, rank, rank_size, pg_options);
}
};

} // namespace xllm
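`get_group_rank` is defined elsewhere in the tree; a plausible reading of what it returns, shown as a hedged Python sketch (the partitioning below is an assumption, not the actual helper):

```python
def get_group_rank(world_size, global_rank, rank_size, trans=False):
    # Assumed behaviour: split world ranks into contiguous groups of
    # rank_size and report this rank's index within its group plus the
    # group's member list (`trans` presumably selects a transposed layout).
    group = global_rank // rank_size
    members = list(range(group * rank_size, (group + 1) * rank_size))
    return global_rank % rank_size, members
```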
2 changes: 2 additions & 0 deletions xllm/core/framework/parallel_state/process_group.cpp
@@ -21,6 +21,8 @@ limitations under the License.
#include "mlu_process_group.h"
#elif defined(USE_CUDA)
#include "cuda_process_group.h"
#elif defined(USE_ILU)
#include "ilu_process_group.h"
#endif

namespace {
5 changes: 5 additions & 0 deletions xllm/core/kernels/CMakeLists.txt
@@ -12,6 +12,10 @@ if(USE_CUDA)
add_subdirectory(cuda)
endif()

if(USE_ILU)
add_subdirectory(ilu)
endif()

cc_library(
NAME
kernels
@@ -25,4 +29,5 @@ cc_library(
$<$<BOOL:${USE_NPU}>:npu_kernels>
$<$<BOOL:${USE_MLU}>:mlu_kernels>
$<$<BOOL:${USE_CUDA}>:cuda_kernels>
$<$<BOOL:${USE_ILU}>:ilu_kernels>
)
28 changes: 28 additions & 0 deletions xllm/core/kernels/ilu/CMakeLists.txt
@@ -0,0 +1,28 @@
include(cc_library)
set(CMAKE_CUDA_ARCHITECTURES ivcore11)
file(GLOB_RECURSE ILU_HEADER_FILES
"${CMAKE_CURRENT_LIST_DIR}/*.h"
)

file(GLOB_RECURSE ILU_SOURCE_FILES
"${CMAKE_CURRENT_LIST_DIR}/*.cpp"
"${CMAKE_CURRENT_LIST_DIR}/*.cu"
)

find_package(Python3 REQUIRED COMPONENTS Interpreter Development)

cc_library(
NAME
ilu_kernels
HDRS
${ILU_HEADER_FILES}
SRCS
${ILU_SOURCE_FILES}
DEPS
torch
:util
ixformer_kernels
ixformer
${Python3_LIBRARIES}
cuinfer
)
32 changes: 32 additions & 0 deletions xllm/core/kernels/ilu/activation.cpp
@@ -0,0 +1,32 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "ilu_ops_api.h"

using namespace ixformer;

namespace xllm::kernel::ilu {

void act_and_mul(torch::Tensor out,
torch::Tensor input,
const std::string& act_mode) {
if (act_mode == "silu") {
infer::silu_and_mul(input, out);
} else {
LOG(FATAL) << "Unsupported act mode: " << act_mode
<< ", only support silu, gelu, gelu_tanh";
}
}
} // namespace xllm::kernel::ilu
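`infer::silu_and_mul` comes from ixformer; a reference for the math it is expected to implement, assuming the usual gated-activation convention (input shaped `[..., 2*d]`, gate in the first half):

```python
import torch

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # Assumed semantics: out = silu(x[..., :d]) * x[..., d:], d = x.size(-1) // 2.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up
```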