Commit 6964967

refactor: replace TORCH_CHECK with CHECK macros and optimize code layout.
1 parent d95324a commit 6964967

14 files changed: +88 -101 lines changed
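
The change is mechanical across the 14 files: every TORCH_CHECK(cond, msg, ...) becomes the matching glog macro (CHECK, CHECK_EQ, CHECK_GT, CHECK_NE) with the message streamed via operator<<, and throw std::runtime_error(...) becomes LOG(FATAL). A minimal sketch of the pattern, with hypothetical helper names (setup_or_die and reject_mode are not xllm functions):

#include <glog/logging.h>

#include <string>

// Sketch only: the TORCH_CHECK -> glog conversion used throughout this commit.
void setup_or_die(int status, const std::string& op_name) {
  // Before: TORCH_CHECK(status == 0, op_name, " setup failed!");
  // CHECK_EQ stringifies the expression and logs both operand values
  // before aborting on failure.
  CHECK_EQ(status, 0) << op_name << " setup failed!";
}

void reject_mode(const std::string& mode) {
  // Before: throw std::runtime_error("unsupported mode: " + mode);
  // LOG(FATAL) writes the message to the log and terminates the process;
  // unlike the thrown exception, callers cannot catch it.
  LOG(FATAL) << "unsupported mode: " << mode;
}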

xllm/core/kernels/npu/active.cpp

Lines changed: 2 additions & 2 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/

+#include <glog/logging.h>
 #include <torch_npu/csrc/aten/CustomFunctions.h>

 #include "npu_ops_api.h"
@@ -22,8 +23,7 @@ namespace xllm::kernel::npu {

 torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
   if (act_mode != "silu" && act_mode != "swiglu") {
-    throw std::runtime_error(
-        "Only swiglu activation is supported in NPU active");
+    LOG(FATAL) << "Only swiglu activation is supported in NPU active";
   }
   return at_npu::native::custom_ops::npu_swiglu(input);
 }
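
This is the first of several throw-to-LOG(FATAL) swaps in the commit, and it changes what callers can observe: the old std::runtime_error was catchable, while LOG(FATAL) logs and aborts. A caller-side sketch (hypothetical call site; input assumed to be a valid NPU tensor):

// Sketch: before this commit a caller could recover from a bad act_mode;
// afterwards LOG(FATAL) terminates the process and the catch block is dead code.
try {
  auto out = xllm::kernel::npu::active(input, "gelu");  // unsupported mode
} catch (const std::runtime_error& e) {
  // reachable only with the old throw-based error path
}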

xllm/core/kernels/npu/attention.cpp

Lines changed: 3 additions & 3 deletions
@@ -46,9 +46,9 @@ void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
                   torch::Tensor& output) {
-  auto head_size = query.size(-1);
-  auto num_heads = query.size(-2);
-  auto num_kv_heads = k_cache.size(-2);
+  int64_t head_size = query.size(-1);
+  int64_t num_heads = query.size(-2);
+  int64_t num_kv_heads = k_cache.size(-2);
   auto q = query.view({-1, num_heads, head_size});
   auto o = output.view({-1, num_heads, head_size});
   atb::_npu_paged_attention(q,

xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp

Lines changed: 4 additions & 5 deletions
@@ -33,7 +33,7 @@ atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) {
       {at::ScalarType::ComplexDouble, ACL_COMPLEX128},
   };

-  TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous");
+  CHECK(at_tensor.is_contiguous()) << "at_tensor is not contiguous";
   atb::Tensor tensor;
   tensor.desc.format = atb::utils::get_format_for_atb(at_tensor);
   if (at_tensor.device().type() == at::kCPU) {
@@ -48,9 +48,8 @@ atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) {
   }

   auto dtype_iterator = dtype_map.find(at_tensor.scalar_type());
-  TORCH_CHECK(dtype_iterator != dtype_map.end(),
-              "not support dtype: ",
-              at_tensor.scalar_type());
+  CHECK(dtype_iterator != dtype_map.end())
+      << "not support dtype: " << at_tensor.scalar_type();
   tensor.desc.dtype = dtype_iterator->second;

   tensor.dataSize = atb::Utils::GetTensorSize(tensor);
@@ -168,7 +167,7 @@ uint64_t operation_setup(atb::VariantPack variant_pack,
   uint64_t workspace_size = 0;
   atb::Status status =
       operation->Setup(variant_pack, workspace_size, context_ptr);
-  TORCH_CHECK(status == 0, operation->GetName(), " setup failed!");
+  CHECK_EQ(status, 0) << operation->GetName() << " setup failed!";
   return workspace_size;
 }


xllm/core/kernels/npu/custom_functions_npu/atb_common.h

Lines changed: 20 additions & 31 deletions
@@ -16,6 +16,7 @@ limitations under the License.
 #pragma once

 #include <dlfcn.h>
+#include <glog/logging.h>
 #include <torch/library.h>
 #include <torch_npu/csrc/core/npu/NPUStream.h>
 #include <torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h>
@@ -30,7 +31,7 @@ namespace atb {

 using aclTensor = struct aclTensor;
 constexpr int64_t MAX_DIM_NUM = 5;
-const int N = 32;
+const int64_t N = 32;

 using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims,
                                         uint64_t view_dims_num,
@@ -87,7 +88,7 @@ inline void* get_api_func_addr(const char* api_name) {
     if (func_addr != nullptr) {
       return func_addr;
     }
-    TORCH_CHECK(false, "get_api_func_addr not found ", api_name);
+    LOG(FATAL) << "get_api_func_addr not found " << api_name;
   }
 }

@@ -119,8 +120,8 @@ inline aclTensor* convert_type(TensorMaintainer& maintainer,
   c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
   // if acl_data_type is ACL_STRING, storageDims is empty.
   if (acl_data_type != ACL_STRING) {
-    TORCH_CHECK(at_tensor.itemsize() > 0,
-                "the itemsize of tensor must be greater than 0.");
+    CHECK_GT(at_tensor.itemsize(), 0)
+        << "the itemsize of tensor must be greater than 0.";
     storageDims.push_back(at_tensor.storage().nbytes() / at_tensor.itemsize());
   }

@@ -245,8 +246,8 @@ inline aclTensor* convert_type_v2(TensorStructPtr at_tensor) {
       atb::utils::convert_to_acl_data_type(scalar_data_type);
   c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
   if (acl_data_type != ACL_STRING) {
-    TORCH_CHECK((*at_tensor).itemsize > 0,
-                "the itemsize of tensor must be greater than 0.");
+    CHECK_GT((*at_tensor).itemsize, 0)
+        << "the itemsize of tensor must be greater than 0.";
     storageDims.push_back((*at_tensor).nbytes / (*at_tensor).itemsize);
   }

@@ -349,16 +350,10 @@ void release_convert_types(Tuple& t) {
   static const auto getWorkspaceSizeFuncAddr = \
       get_api_func_addr(#atb_api "GetWorkspaceSize"); \
   static const auto atbApiFuncAddr = get_api_func_addr(#atb_api); \
-  TORCH_CHECK( \
-      getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr, \
-      #atb_api, \
-      " or ", \
-      #atb_api "GetWorkspaceSize", \
-      " not in ", \
-      get_atb_api_lib_name(), \
-      ", or ", \
-      get_atb_api_lib_name(), \
-      "not found."); \
+  CHECK(getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr) \
+      << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \
+      << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \
+      << "not found."; \
   auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
   auto context_ptr = atb::utils::get_context(acl_stream); \
   uint64_t workspace_size = 0; \
@@ -374,7 +369,7 @@ void release_convert_types(Tuple& t) {
   static auto getWorkspaceSizeFunc = \
       convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \
   auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
-  TORCH_CHECK(workspace_status == 0, "call " #atb_api " failed, detail:"); \
+  CHECK_EQ(workspace_status, 0) << "call " #atb_api " failed, detail:"; \
   void* workspace_addr = nullptr; \
   at::Tensor workspace_tensor; \
   if (workspace_size != 0) { \
@@ -395,7 +390,7 @@ void release_convert_types(Tuple& t) {
   AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(atbApiFuncAddr); \
   auto api_ret = \
       atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \
-  TORCH_CHECK(api_ret == 0, "call " #atb_api " failed, detail:"); \
+  CHECK_EQ(api_ret, 0) << "call " #atb_api " failed, detail:"; \
   DestroyOperation(op); \
   release_convert_types(converted_params); \
   return api_ret; \
@@ -408,16 +403,10 @@ void release_convert_types(Tuple& t) {
   static const auto getWorkspaceSizeFuncAddr = \
       get_api_func_addr(#atb_api "GetWorkspaceSize"); \
   static const auto AtbApiFuncAddr = get_api_func_addr(#atb_api); \
-  TORCH_CHECK( \
-      getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr, \
-      #atb_api, \
-      " or ", \
-      #atb_api "GetWorkspaceSize", \
-      " not in ", \
-      get_atb_api_lib_name(), \
-      ", or ", \
-      get_atb_api_lib_name(), \
-      "not found."); \
+  CHECK(getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr) \
+      << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \
+      << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \
+      << "not found."; \
   auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
   TensorMaintainer tensor_maintainer; \
   auto copied_params = copy_types_v2(tensor_maintainer, __VA_ARGS__); \
@@ -440,8 +429,8 @@ void release_convert_types(Tuple& t) {
       convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \
   auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
   opParamCache.save_operation(hash_id, op); \
-  TORCH_CHECK(workspace_status == 0, \
-              "call " #atb_api "GetWorkspaceSize failed"); \
+  CHECK_EQ(workspace_status, 0) \
+      << "call " #atb_api "GetWorkspaceSize failed"; \
   void* workspace_addr = nullptr; \
   at::Tensor workspace_tensor; \
   if (workspace_size != 0) { \
@@ -451,7 +440,7 @@ void release_convert_types(Tuple& t) {
   } \
   AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(AtbApiFuncAddr); \
   api_ret = atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \
-  TORCH_CHECK(api_ret == 0, "call " #atb_api " failed"); \
+  CHECK_EQ(api_ret, 0) << "call " #atb_api " failed"; \
   release_convert_types(converted_params); \
   return api_ret; \
 }; \
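
A side benefit of CHECK_EQ inside these backslash-continued macros: glog stringifies the comparison and prints both operand values, so the old multi-line, hand-formatted TORCH_CHECK messages collapse into a short streamed suffix. A standalone sketch of the diagnostic (the operation name AtbRmsNorm is made up):

#include <glog/logging.h>

// Sketch: on failure glog emits something like
//   Check failed: status == 0 (1 vs. 0) call AtbRmsNorm failed
// and then aborts, so the call site need not format the status itself.
void check_atb_status(int status) {
  CHECK_EQ(status, 0) << "call AtbRmsNorm failed";
}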

xllm/core/kernels/npu/custom_functions_npu/operation_create.h

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@ limitations under the License.

 #pragma once

+#include <glog/logging.h>
 #include <torch_npu/csrc/core/npu/NPUGraphsUtils.h>
 #include <torch_npu/csrc/framework/OpCommand.h>

@@ -55,7 +56,7 @@ atb::Operation* create_atb_operation(const ParamType& param,
                                      const std::string& name) {
   atb::Operation* op = nullptr;
   atb::CreateOperation(param, &op);
-  TORCH_CHECK(op != nullptr, name, " CreateOperation failed!");
+  CHECK(op != nullptr) << name << " CreateOperation failed!";
   return op;
 }


xllm/core/kernels/npu/custom_functions_npu/utils.cpp

Lines changed: 4 additions & 4 deletions
@@ -30,15 +30,15 @@ ContextManager::ContextManager() : atb_context_(nullptr) {}
 ContextManager::~ContextManager() {
   if (atb_context_) {
     auto status = atb::DestroyContext(atb_context_);
-    TORCH_CHECK(status == 0, "Destroy context failed!");
+    CHECK_EQ(status, 0) << "Destroy context failed!";
     atb_context_ = nullptr;
   }
 }

 atb::Context* ContextManager::get_context(aclrtStream stream) {
   std::call_once(create_flag_, [this]() {
     auto status = atb::CreateContext(&atb_context_);
-    TORCH_CHECK(status == 0, "Create context failed!");
+    CHECK_EQ(status, 0) << "Create context failed!";
   });

   atb_context_->SetExecuteStream(stream);
@@ -52,8 +52,8 @@ atb::Context* get_context(aclrtStream stream) {
 aclDataType convert_to_acl_data_type(const at::ScalarType& data_type) {
   auto acl_dtype =
       kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(data_type)];
-  TORCH_CHECK(acl_dtype != ACL_DT_UNDEFINED,
-              std::string(c10::toString(data_type)) + " has not been supported")
+  CHECK_NE(acl_dtype, ACL_DT_UNDEFINED)
+      << std::string(c10::toString(data_type)) << " has not been supported";
   return acl_dtype;
 }


xllm/core/kernels/npu/custom_functions_npu/utils.h

Lines changed: 3 additions & 6 deletions
@@ -17,6 +17,7 @@ limitations under the License.

 #include <ATen/ATen.h>
 #include <acl/acl.h>
+#include <glog/logging.h>
 #include <torch_npu/csrc/core/npu/NPUFormat.h>

 #include "atb/atb_infer.h"
@@ -88,12 +89,8 @@ inline int get_op_mode(const MapType& mode_map,
                        const char* mode_name) {
   c10::string_view mode_str = mode_opt.value_or(default_mode);
   auto it = mode_map.find(mode_str);
-  TORCH_CHECK(it != mode_map.end(),
-              "Unsupported ",
-              mode_name,
-              " value: '",
-              mode_str,
-              "'");
+  CHECK(it != mode_map.end())
+      << "Unsupported " << mode_name << " value: '" << mode_str << "'";
   return it->second;
 }
 }  // namespace utils

xllm/core/kernels/npu/fused_layernorm.cpp

Lines changed: 6 additions & 6 deletions
@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include <glog/logging.h>
 #include <torch_npu/csrc/aten/CustomFunctions.h>

 #include "npu_ops_api.h"
 #include "ops_npu/npu_ops.h"

 namespace xllm::kernel::npu {

-torch::Tensor fused_layernorm(const torch::Tensor& input,
-                              const torch::Tensor& weight,
-                              double eps,
-                              const std::string& mode) {
+torch::Tensor rms_norm(const torch::Tensor& input,
+                       const torch::Tensor& weight,
+                       double eps,
+                       const std::string& mode) {
   if (mode != "rmsnorm") {
-    throw std::runtime_error(
-        "Only rmsnorm mode is supported in NPU fused_layernorm");
+    LOG(FATAL) << "Only rmsnorm mode is supported in NPU rms_norm";
   }
   std::tuple<at::Tensor, at::Tensor> result =
       at_npu::native::custom_ops::npu_rms_norm(input, weight, eps);
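
Call sites migrate from fused_layernorm(...) to rms_norm(...) with an otherwise unchanged signature, as the npu_ops_api.h diff below confirms. A hedged usage sketch (shapes illustrative; assumes an NPU build with the tensors on the NPU device):

// Sketch: calling the renamed entry point.
torch::Tensor x = torch::randn({4, 1024});  // illustrative activation
torch::Tensor w = torch::ones({1024});      // per-channel RMSNorm weight
torch::Tensor y = xllm::kernel::npu::rms_norm(x, w, /*eps=*/1e-6, "rmsnorm");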

xllm/core/kernels/npu/npu_ops_api.h

Lines changed: 5 additions & 5 deletions
@@ -18,7 +18,7 @@ limitations under the License.

 #include <optional>

-#include "./custom_functions_npu/atb_common.h"
+#include "custom_functions_npu/atb_common.h"

 namespace xllm::kernel::npu {

@@ -50,10 +50,10 @@ torch::Tensor matmul(const torch::Tensor& a,

 torch::Tensor active(const torch::Tensor& input, const std::string& act_mode);

-torch::Tensor fused_layernorm(const torch::Tensor& input,
-                              const torch::Tensor& weight,
-                              double eps,
-                              const std::string& mode);
+torch::Tensor rms_norm(const torch::Tensor& input,
+                       const torch::Tensor& weight,
+                       double eps,
+                       const std::string& mode);

 void apply_rotary(torch::Tensor& q,
                   torch::Tensor& k,

xllm/core/kernels/npu/ops_npu/npu_ops.h

Lines changed: 1 addition & 5 deletions
@@ -14,16 +14,12 @@ limitations under the License.
 ==============================================================================*/
 #pragma once

-#include "../custom_functions_npu/atb_common.h"
+#include "kernels/npu/custom_functions_npu/atb_common.h"

 using namespace std;

 namespace atb {

-using PagedAttentionParam = atb::infer::PagedAttentionParam;
-using ReshapeAndCacheParam = atb::infer::ReshapeAndCacheParam;
-using SelfAttentionParam = atb::infer::SelfAttentionParam;
-
 void _npu_paged_attention(const at::Tensor& query,
                           const at::Tensor& key_cache,
                           const at::Tensor& value_cache,
