@@ -16,6 +16,7 @@ limitations under the License.
1616#pragma once
1717
1818#include < dlfcn.h>
19+ #include < glog/logging.h>
1920#include < torch/library.h>
2021#include < torch_npu/csrc/core/npu/NPUStream.h>
2122#include < torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h>
@@ -30,7 +31,7 @@ namespace atb {
3031
3132using aclTensor = struct aclTensor ;
3233constexpr int64_t MAX_DIM_NUM = 5 ;
33- const int N = 32 ;
34+ const int64_t N = 32 ;
3435
3536using _aclCreateTensor = aclTensor* (*)(const int64_t * view_dims,
3637 uint64_t view_dims_num,
@@ -87,7 +88,7 @@ inline void* get_api_func_addr(const char* api_name) {
8788 if (func_addr != nullptr ) {
8889 return func_addr;
8990 }
90- TORCH_CHECK ( false , " get_api_func_addr not found " , api_name) ;
91+ LOG (FATAL) << " get_api_func_addr not found " << api_name;
9192 }
9293}
9394
@@ -119,8 +120,8 @@ inline aclTensor* convert_type(TensorMaintainer& maintainer,
119120 c10::SmallVector<int64_t , MAX_DIM_NUM> storageDims;
120121 // if acl_data_type is ACL_STRING, storageDims is empty.
121122 if (acl_data_type != ACL_STRING) {
122- TORCH_CHECK (at_tensor.itemsize () > 0 ,
123- " the itemsize of tensor must be greater than 0." ) ;
123+ CHECK_GT (at_tensor.itemsize (), 0 )
124+ << " the itemsize of tensor must be greater than 0." ;
124125 storageDims.push_back (at_tensor.storage ().nbytes () / at_tensor.itemsize ());
125126 }
126127
@@ -245,8 +246,8 @@ inline aclTensor* convert_type_v2(TensorStructPtr at_tensor) {
245246 atb::utils::convert_to_acl_data_type (scalar_data_type);
246247 c10::SmallVector<int64_t , MAX_DIM_NUM> storageDims;
247248 if (acl_data_type != ACL_STRING) {
248- TORCH_CHECK ((*at_tensor).itemsize > 0 ,
249- " the itemsize of tensor must be greater than 0." ) ;
249+ CHECK_GT ((*at_tensor).itemsize , 0 )
250+ << " the itemsize of tensor must be greater than 0." ;
250251 storageDims.push_back ((*at_tensor).nbytes / (*at_tensor).itemsize );
251252 }
252253
@@ -349,16 +350,10 @@ void release_convert_types(Tuple& t) {
349350 static const auto getWorkspaceSizeFuncAddr = \
350351 get_api_func_addr (#atb_api " GetWorkspaceSize" ); \
351352 static const auto atbApiFuncAddr = get_api_func_addr (#atb_api); \
352- TORCH_CHECK ( \
353- getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr , \
354- #atb_api, \
355- " or " , \
356- #atb_api " GetWorkspaceSize" , \
357- " not in " , \
358- get_atb_api_lib_name (), \
359- " , or " , \
360- get_atb_api_lib_name (), \
361- " not found." ); \
353+ CHECK (getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr ) \
354+ << #atb_api << " or " << #atb_api " GetWorkspaceSize" << " not in " \
355+ << get_atb_api_lib_name () << " , or " << get_atb_api_lib_name () \
356+ << " not found." ; \
362357 auto acl_stream = c10_npu::getCurrentNPUStream ().stream (false ); \
363358 auto context_ptr = atb::utils::get_context (acl_stream); \
364359 uint64_t workspace_size = 0 ; \
@@ -374,7 +369,7 @@ void release_convert_types(Tuple& t) {
374369 static auto getWorkspaceSizeFunc = \
375370 convert_to_op_api_func (converted_params, getWorkspaceSizeFuncAddr); \
376371 auto workspace_status = call (getWorkspaceSizeFunc, converted_params); \
377- TORCH_CHECK (workspace_status == 0 , " call " #atb_api " failed, detail:" ); \
372+ CHECK_EQ (workspace_status, 0 ) << " call " #atb_api " failed, detail:" ; \
378373 void * workspace_addr = nullptr ; \
379374 at::Tensor workspace_tensor; \
380375 if (workspace_size != 0 ) { \
@@ -395,7 +390,7 @@ void release_convert_types(Tuple& t) {
395390 AtbApiFunc atbApiFunc = reinterpret_cast <AtbApiFunc>(atbApiFuncAddr); \
396391 auto api_ret = \
397392 atbApiFunc (workspace_addr, workspace_size, op, context_ptr); \
398- TORCH_CHECK (api_ret == 0 , " call " #atb_api " failed, detail:" ); \
393+ CHECK_EQ (api_ret, 0 ) << " call " #atb_api " failed, detail:" ; \
399394 DestroyOperation (op); \
400395 release_convert_types (converted_params); \
401396 return api_ret; \
@@ -408,16 +403,10 @@ void release_convert_types(Tuple& t) {
408403 static const auto getWorkspaceSizeFuncAddr = \
409404 get_api_func_addr (#atb_api " GetWorkspaceSize" ); \
410405 static const auto AtbApiFuncAddr = get_api_func_addr (#atb_api); \
411- TORCH_CHECK ( \
412- getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr , \
413- #atb_api, \
414- " or " , \
415- #atb_api " GetWorkspaceSize" , \
416- " not in " , \
417- get_atb_api_lib_name (), \
418- " , or " , \
419- get_atb_api_lib_name (), \
420- " not found." ); \
406+ CHECK (getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr ) \
407+ << #atb_api << " or " << #atb_api " GetWorkspaceSize" << " not in " \
408+ << get_atb_api_lib_name () << " , or " << get_atb_api_lib_name () \
409+ << " not found." ; \
421410 auto acl_stream = c10_npu::getCurrentNPUStream ().stream (false ); \
422411 TensorMaintainer tensor_maintainer; \
423412 auto copied_params = copy_types_v2 (tensor_maintainer, __VA_ARGS__); \
@@ -440,8 +429,8 @@ void release_convert_types(Tuple& t) {
440429 convert_to_op_api_func (converted_params, getWorkspaceSizeFuncAddr); \
441430 auto workspace_status = call (getWorkspaceSizeFunc, converted_params); \
442431 opParamCache.save_operation (hash_id, op); \
443- TORCH_CHECK (workspace_status == 0 , \
444- " call " #atb_api " GetWorkspaceSize failed" ); \
432+ CHECK_EQ (workspace_status, 0 ) \
433+ << " call " #atb_api " GetWorkspaceSize failed" ; \
445434 void * workspace_addr = nullptr ; \
446435 at::Tensor workspace_tensor; \
447436 if (workspace_size != 0 ) { \
@@ -451,7 +440,7 @@ void release_convert_types(Tuple& t) {
451440 } \
452441 AtbApiFunc atbApiFunc = reinterpret_cast <AtbApiFunc>(AtbApiFuncAddr); \
453442 api_ret = atbApiFunc (workspace_addr, workspace_size, op, context_ptr); \
454- TORCH_CHECK (api_ret == 0 , " call " #atb_api " failed" ); \
443+ CHECK_EQ (api_ret, 0 ) << " call " #atb_api " failed" ; \
455444 release_convert_types (converted_params); \
456445 return api_ret; \
457446 }; \
0 commit comments