Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 117 additions & 30 deletions cuda_core/cuda/core/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import threading
import weakref
from collections import namedtuple
from typing import Union
Expand All @@ -17,6 +18,14 @@
)
from cuda.core._utils.cuda_utils import driver, get_binding_version, handle_return, precondition

# Lazy initialization state and synchronization
# For Python 3.13t (free-threaded builds), we use a lock to ensure thread-safe initialization.
# For regular Python builds with GIL, the lock overhead is minimal and the code remains safe.
_init_lock = threading.Lock()
_inited = False
_py_major_ver = None
_driver_ver = None
_kernel_ctypes = None
_backend = {
"old": {
"file": driver.cuModuleLoad,
Expand All @@ -27,35 +36,75 @@
}


# TODO: revisit this treatment for py313t builds
_inited = False
_py_major_ver = None
_driver_ver = None
_kernel_ctypes = None
def _lazy_init():
    """Populate the module-level binding/driver state on first use.

    The first caller records the binding major version, registers the "new"
    backend table (cuLibrary/cuKernel entry points) when the bindings are
    version 12 or newer, selects the accepted kernel handle types, and
    queries the driver version. Every later call returns immediately.

    ``_inited`` is tested once without ``_init_lock`` and again after
    acquiring it (double-checked locking), so a single thread performs the
    initialization while already-initialized callers skip the lock entirely.
    """
    global _inited, _py_major_ver, _driver_ver, _kernel_ctypes

    # Unlocked check: nothing to do once initialization has completed.
    if _inited:
        return

    with _init_lock:
        # Re-check under the lock; another thread may have finished meanwhile.
        if _inited:
            return

        # Which driver entry points exist depends on the cuda-python version.
        _py_major_ver, _ = get_binding_version()
        has_library_api = _py_major_ver >= 12
        if has_library_api:
            _backend["new"] = {
                "file": driver.cuLibraryLoadFromFile,
                "data": driver.cuLibraryLoadData,
                "kernel": driver.cuLibraryGetKernel,
                "attribute": driver.cuKernelGetAttribute,
            }
            _kernel_ctypes = (driver.CUfunction, driver.CUkernel)
        else:
            _kernel_ctypes = (driver.CUfunction,)

        _driver_ver = handle_return(driver.cuDriverGetVersion())
        if has_library_api and _driver_ver >= 12040:
            # Parameter introspection is only registered for drivers >= 12.4.
            _backend["new"]["paraminfo"] = driver.cuKernelGetParamInfo

        # Flip the flag last so the unlocked check never observes partial state.
        _inited = True


# Auto-initializing property accessors
def _get_py_major_ver():
    """Return the binding major version recorded by ``_lazy_init()``.

    Triggers the lazy initialization first, so the value is populated even
    when this is the first access to the module state.
    """
    _lazy_init()
    return _py_major_ver


def _get_driver_ver():
    """Return the driver version recorded by ``_lazy_init()``.

    Triggers the lazy initialization first; the value is whatever
    ``cuDriverGetVersion`` reported at that time.
    """
    _lazy_init()
    return _driver_ver


def _get_kernel_ctypes():
    """Return the tuple of driver types accepted as kernel handles.

    Triggers the lazy initialization first; the tuple holds ``CUfunction``
    and, when the bindings are version 12 or newer, ``CUkernel`` as well.
    """
    _lazy_init()
    return _kernel_ctypes


def _get_backend_version():
    """Select the key of the ``_backend`` table to use.

    Returns ``"new"`` when both the binding major version is at least 12 and
    the driver version is at least 12000, and ``"old"`` in every other case.
    """
    # The driver version is only consulted when the bindings are new enough.
    if _get_py_major_ver() >= 12 and _get_driver_ver() >= 12000:
        return "new"
    return "old"


class KernelAttributes:
Expand All @@ -70,7 +119,7 @@ def _init(cls, kernel):
self._kernel = weakref.ref(kernel)
self._cache = {}

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._backend_version = _get_backend_version()
self._loader = _backend[self._backend_version]
return self

Expand Down Expand Up @@ -197,7 +246,9 @@ def cluster_scheduling_policy_preference(self, device_id: Device | int = None) -


class KernelOccupancy:
""" """
"""This class offers methods to query occupancy metrics that help determine optimal
launch parameters such as block size, grid size, and shared memory usage.
"""

def __new__(self, *args, **kwargs):
raise RuntimeError("KernelOccupancy cannot be instantiated directly. Please use Kernel APIs.")
Expand Down Expand Up @@ -378,7 +429,7 @@ def __new__(self, *args, **kwargs):

@classmethod
def _from_obj(cls, obj, mod):
assert_type(obj, _kernel_ctypes)
assert_type(obj, _get_kernel_ctypes())
assert_type(mod, ObjectCode)
ker = super().__new__(cls)
ker._handle = obj
Expand All @@ -399,9 +450,10 @@ def _get_arguments_info(self, param_info=False) -> tuple[int, list[ParamInfo]]:
if attr_impl._backend_version != "new":
raise NotImplementedError("New backend is required")
if "paraminfo" not in attr_impl._loader:
driver_ver = _get_driver_ver()
raise NotImplementedError(
"Driver version 12.4 or newer is required for this function. "
f"Using driver version {_driver_ver // 1000}.{(_driver_ver % 1000) // 10}"
f"Using driver version {driver_ver // 1000}.{(driver_ver % 1000) // 10}"
)
arg_pos = 0
param_info_data = []
Expand Down Expand Up @@ -436,7 +488,43 @@ def occupancy(self) -> KernelOccupancy:
self._occupancy = KernelOccupancy._init(self._handle)
return self._occupancy

# TODO: implement from_handle()
@staticmethod
def from_handle(handle: int, mod: "ObjectCode" = None) -> "Kernel":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would open a can of worms and we should be careful. Mainly because

  1. we cover two sets of driver APIs (function/module vs kernel/library)
  2. function/kernel are mostly interchangeable, but their parents (module/library) are not IIRC
  3. bidirectional lookup cannot be done between function and module, but can be between kernel and library

We should discuss this in depth after the holidays.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy New Year! 🎄

"""Creates a new :obj:`Kernel` object from a foreign kernel handle.

Uses a CUfunction or CUkernel pointer address to create a new :obj:`Kernel` object.

Parameters
----------
handle : int
Kernel handle representing the address of a foreign
kernel object (CUfunction or CUkernel).
mod : :obj:`ObjectCode`, optional
The ObjectCode object associated with this kernel. If not provided,
a placeholder ObjectCode will be created. Note that without a proper
ObjectCode, certain operations may be limited.
"""

# Validate that handle is an integer
if not isinstance(handle, int):
raise TypeError(f"handle must be an integer, got {type(handle).__name__}")

# Convert the integer handle to the appropriate driver type
if _get_py_major_ver() >= 12 and _get_driver_ver() >= 12000:
# Try CUkernel first for newer CUDA versions
kernel_obj = driver.CUkernel(handle)
else:
# Use CUfunction for older versions
kernel_obj = driver.CUfunction(handle)

# If no module provided, create a placeholder
if mod is None:
# Create a placeholder ObjectCode that won't try to load anything
mod = ObjectCode._init(b"", "cubin")
# Set a dummy handle to prevent lazy loading
mod._handle = 1 # Non-null placeholder

return Kernel._from_obj(kernel_obj, mod)


CodeTypeT = Union[bytes, bytearray, str]
Expand Down Expand Up @@ -474,12 +562,11 @@ def __new__(self, *args, **kwargs):
def _init(cls, module, code_type, *, name: str = "", symbol_mapping: dict | None = None):
self = super().__new__(cls)
assert code_type in self._supported_code_type, f"{code_type=} is not supported"
_lazy_init()

# handle is assigned during _lazy_load
self._handle = None

self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
self._backend_version = _get_backend_version()
self._loader = _backend[self._backend_version]

self._code_type = code_type
Expand Down
87 changes: 87 additions & 0 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,90 @@ def test_module_serialization_roundtrip(get_saxpy_kernel_cubin):
assert objcode.code == result.code
assert objcode._sym_map == result._sym_map
assert objcode.code_type == result.code_type


def test_kernel_from_handle(get_saxpy_kernel_cubin):
    """Kernel.from_handle() wraps a valid handle when its ObjectCode is supplied."""
    source_kernel, objcode = get_saxpy_kernel_cubin
    kernel_cls = cuda.core._module.Kernel

    # Round-trip the raw integer address of an existing kernel.
    raw_handle = int(source_kernel._handle)
    wrapped = kernel_cls.from_handle(raw_handle, objcode)
    assert isinstance(wrapped, kernel_cls)

    # The wrapper must still be able to answer attribute queries.
    thread_limit = wrapped.attributes.max_threads_per_block()
    assert isinstance(thread_limit, int)
    assert thread_limit > 0


def test_kernel_from_handle_no_module(get_saxpy_kernel_cubin):
    """Kernel.from_handle() works when the ``mod`` argument is omitted."""
    source_kernel, _ = get_saxpy_kernel_cubin
    kernel_cls = cuda.core._module.Kernel

    # Only the raw integer address is passed; no ObjectCode is given.
    raw_handle = int(source_kernel._handle)
    wrapped = kernel_cls.from_handle(raw_handle)
    assert isinstance(wrapped, kernel_cls)

    # Attribute queries are expected to keep working without a module.
    thread_limit = wrapped.attributes.max_threads_per_block()
    assert isinstance(thread_limit, int)
    assert thread_limit > 0


@pytest.mark.parametrize(
    "invalid_value",
    [
        pytest.param("not_an_int", id="str"),
        pytest.param(2.71828, id="float"),
        pytest.param(None, id="None"),
        pytest.param({"handle": 123}, id="dict"),
        pytest.param([456], id="list"),
        pytest.param((789,), id="tuple"),
        pytest.param(3 + 4j, id="complex"),
        pytest.param(b"\xde\xad\xbe\xef", id="bytes"),
        pytest.param({999}, id="set"),
        pytest.param(object(), id="object"),
    ],
)
def test_kernel_from_handle_type_validation(invalid_value):
    """Kernel.from_handle() raises TypeError for every non-integer handle."""
    make_kernel = cuda.core._module.Kernel.from_handle
    with pytest.raises(TypeError):
        make_kernel(invalid_value)


def test_kernel_from_handle_invalid_module_type(get_saxpy_kernel_cubin):
    """Kernel.from_handle() rejects a ``mod`` argument that is not an ObjectCode."""
    source_kernel, _ = get_saxpy_kernel_cubin
    raw_handle = int(source_kernel._handle)

    # Each bogus module value should trip the type assertion in _from_obj.
    for bad_mod in ("not_an_objectcode", 12345):
        with pytest.raises((TypeError, AssertionError)):
            cuda.core._module.Kernel.from_handle(raw_handle, mod=bad_mod)


def test_kernel_from_handle_multiple_instances(get_saxpy_kernel_cubin):
    """Several Kernel objects can be created from one and the same handle."""
    source_kernel, objcode = get_saxpy_kernel_cubin
    kernel_cls = cuda.core._module.Kernel
    raw_handle = int(source_kernel._handle)

    # Wrap the same address three times.
    wrappers = [kernel_cls.from_handle(raw_handle, objcode) for _ in range(3)]

    # Every wrapper is a proper Kernel object ...
    for wrapper in wrappers:
        assert isinstance(wrapper, kernel_cls)

    # ... and all of them point at the original CUDA kernel handle.
    assert all(int(wrapper._handle) == raw_handle for wrapper in wrappers)
Loading