Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions keopscore/keopscore/config/cuda.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
import os
import ctypes
from ctypes.util import find_library
import os
import platform
import shutil
import subprocess
import sys
import tempfile
from ctypes import (
c_int,
c_void_p,
c_char_p,
CDLL,
byref,
cast,
POINTER,
Structure,
RTLD_GLOBAL,
Structure,
byref,
c_char_p,
c_int,
c_void_p,
cast,
)
from pathlib import Path
import shutil
from ctypes.util import find_library
from os.path import join
import platform
import tempfile
import subprocess
import sys
from pathlib import Path

import keopscore
from keopscore.utils.misc_utils import KeOps_Warning
from keopscore.utils.misc_utils import KeOps_OS_Run
from keopscore.utils.misc_utils import CHECK_MARK, CROSS_MARK
from keopscore.utils.misc_utils import (
CHECK_MARK,
CROSS_MARK,
KeOps_OS_Run,
KeOps_Warning,
)


class CUDAConfig:
Expand Down Expand Up @@ -300,12 +304,18 @@ def get_cuda_include_path(self):
# Check if CUDA is installed via conda
conda_prefix = os.getenv("CONDA_PREFIX")
if conda_prefix:
include_path = Path(conda_prefix) / "include"
if (include_path / "cuda.h").is_file() and (
include_path / "nvrtc.h"
).is_file():
self.cuda_include_path = str(include_path)
return self.cuda_include_path
for arch in [
"",
"targets/x86_64-linux",
"targets/ppc64le-linux",
"targets/sbsa-linux",
]:
include_path = Path(conda_prefix) / arch / "include"
if (include_path / "cuda.h").is_file() and (
include_path / "nvrtc.h"
).is_file():
self.cuda_include_path = str(include_path)
return self.cuda_include_path

# Check standard locations
cuda_version_str = self.get_cuda_version(out_type="string")
Expand Down