diff --git a/keopscore/keopscore/config/cuda.py b/keopscore/keopscore/config/cuda.py index a52e2b5b1..dd6bc62ac 100644 --- a/keopscore/keopscore/config/cuda.py +++ b/keopscore/keopscore/config/cuda.py @@ -1,28 +1,32 @@ -import os import ctypes -from ctypes.util import find_library +import os +import platform +import shutil +import subprocess +import sys +import tempfile from ctypes import ( - c_int, - c_void_p, - c_char_p, CDLL, - byref, - cast, POINTER, - Structure, RTLD_GLOBAL, + Structure, + byref, + c_char_p, + c_int, + c_void_p, + cast, ) -from pathlib import Path -import shutil +from ctypes.util import find_library from os.path import join -import platform -import tempfile -import subprocess -import sys +from pathlib import Path + import keopscore -from keopscore.utils.misc_utils import KeOps_Warning -from keopscore.utils.misc_utils import KeOps_OS_Run -from keopscore.utils.misc_utils import CHECK_MARK, CROSS_MARK +from keopscore.utils.misc_utils import ( + CHECK_MARK, + CROSS_MARK, + KeOps_OS_Run, + KeOps_Warning, +) class CUDAConfig: @@ -300,12 +304,18 @@ def get_cuda_include_path(self): # Check if CUDA is installed via conda conda_prefix = os.getenv("CONDA_PREFIX") if conda_prefix: - include_path = Path(conda_prefix) / "include" - if (include_path / "cuda.h").is_file() and ( - include_path / "nvrtc.h" - ).is_file(): - self.cuda_include_path = str(include_path) - return self.cuda_include_path + for arch in [ + "", + "targets/x86_64-linux", + "targets/ppc64le-linux", + "targets/sbsa-linux", + ]: + include_path = Path(conda_prefix) / arch / "include" + if (include_path / "cuda.h").is_file() and ( + include_path / "nvrtc.h" + ).is_file(): + self.cuda_include_path = str(include_path) + return self.cuda_include_path # Check standard locations cuda_version_str = self.get_cuda_version(out_type="string")