From 7e5001d7538da206847b8dfb21db8de8dc83fac8 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 24 Dec 2025 12:58:23 -0800 Subject: [PATCH] feat: Restrict `__slots__=()` for subclasses of `tvm_ffi.Object` by default --- docs/conf.py | 7 +- python/tvm_ffi/core.pyi | 6 +- python/tvm_ffi/cython/base.pxi | 1 + python/tvm_ffi/cython/device.pxi | 1 + python/tvm_ffi/cython/dtype.pxi | 1 + python/tvm_ffi/cython/error.pxi | 7 +- python/tvm_ffi/cython/function.pxi | 48 +++---- python/tvm_ffi/cython/object.pxi | 201 +++++++++++++++++----------- python/tvm_ffi/cython/string.pxi | 4 +- python/tvm_ffi/cython/tensor.pxi | 12 +- python/tvm_ffi/cython/type_info.pxi | 8 +- python/tvm_ffi/module.py | 16 ++- tests/python/test_object.py | 32 +++++ 13 files changed, 221 insertions(+), 123 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index a575ee04..3ea87b04 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -282,7 +282,9 @@ def _link_inherited_members(app, what, name, obj, options, lines) -> None: # no # If it comes from builtins we already hide it; no link needed if base in _py_native_classes or getattr(base, "__module__", "") == "builtins": return - owner_fq = f"{base.__module__}.{base.__qualname__}".replace("tvm_ffi.core.", "tvm_ffi.") + owner_fq = f"{base.__module__}.{base.__qualname__}".replace( + "tvm_ffi.core.", "tvm_ffi." + ).replace(".CObject", ".Object") role = "attr" if what in {"attribute", "property"} else "meth" lines.clear() lines.append( @@ -329,6 +331,9 @@ def _import_cls(cls_name: str) -> type | None: "__ffi_init__", "__from_extern_c__", "__from_mlir_packed_safe_call__", + "_move", + "__move_handle_from__", + "__init_handle_by_constructor__", } # If a member method comes from one of these native types, hide it in the docs _py_native_classes: tuple[type, ...] = ( diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index 2cee79cc..8930ac7c 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -32,7 +32,7 @@ _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType | None], str] | None # DLPack protocol version (defined in tensor.pxi) __dlpack_version__: tuple[int, int] -class Object: +class CObject: def __ctypes_handle__(self) -> Any: ... def __chandle__(self) -> int: ... def __reduce__(self) -> Any: ... @@ -46,7 +46,9 @@ class Object: def __ffi_init__(self, *args: Any) -> None: ... def same_as(self, other: Any) -> bool: ... def _move(self) -> ObjectRValueRef: ... - def __move_handle_from__(self, other: Object) -> None: ... + def __move_handle_from__(self, other: CObject) -> None: ... + +class Object(CObject): ... class ObjectConvertible: def asobject(self) -> Object: ... diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi index 9e7ff877..e4b7a252 100644 --- a/python/tvm_ffi/cython/base.pxi +++ b/python/tvm_ffi/cython/base.pxi @@ -392,6 +392,7 @@ cdef extern from "tvm_ffi_python_helpers.h": cdef class ByteArrayArg: + __slots__ = () cdef TVMFFIByteArray cdata cdef object py_data diff --git a/python/tvm_ffi/cython/device.pxi b/python/tvm_ffi/cython/device.pxi index 2eb36fc9..bf4dce4f 100644 --- a/python/tvm_ffi/cython/device.pxi +++ b/python/tvm_ffi/cython/device.pxi @@ -91,6 +91,7 @@ cdef class Device: assert str(dev) == "cuda:0" """ + __slots__ = () cdef DLDevice cdevice _DEVICE_TYPE_TO_NAME = { diff --git a/python/tvm_ffi/cython/dtype.pxi b/python/tvm_ffi/cython/dtype.pxi index 3b8530a4..5c242faf 100644 --- a/python/tvm_ffi/cython/dtype.pxi +++ b/python/tvm_ffi/cython/dtype.pxi @@ -79,6 +79,7 @@ cdef class DataType: assert str(d) == "int32" """ + __slots__ = () cdef DLDataType cdtype def __init__(self, dtype_str: str) -> None: diff --git a/python/tvm_ffi/cython/error.pxi b/python/tvm_ffi/cython/error.pxi index d85cb873..cf87c5d6 100644 --- a/python/tvm_ffi/cython/error.pxi +++ b/python/tvm_ffi/cython/error.pxi @@ -27,7 +27,7 @@ _WITH_APPEND_BACKTRACE: Optional[Callable[[BaseException, str], BaseException]] _TRACEBACK_TO_BACKTRACE_STR: Optional[Callable[[types.TracebackType | None], str]] = None -cdef class Error(Object): +cdef class Error(CObject): """Base class for FFI errors. An :class:`Error` is a lightweight wrapper around a concrete Python @@ -43,6 +43,7 @@ cdef class Error(Object): Do not directly raise this object. Instead, use :py:meth:`py_error` to convert it to a Python exception and raise that. """ + __slots__ = () def __init__(self, kind: str, message: str, backtrace: str): """Construct an error wrapper. @@ -66,7 +67,7 @@ cdef class Error(Object): ) if ret != 0: raise MemoryError("Failed to create error object") - (self).chandle = out + (self).chandle = out def update_backtrace(self, backtrace: str) -> None: """Replace the stored backtrace string with ``backtrace``. @@ -107,7 +108,7 @@ cdef class Error(Object): cdef inline Error move_from_last_error(): # raise last error error = Error.__new__(Error) - TVMFFIErrorMoveFromRaised(&(error).chandle) + TVMFFIErrorMoveFromRaised(&(error).chandle) return error diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi index 01a33667..1b362d6e 100644 --- a/python/tvm_ffi/cython/function.pxi +++ b/python/tvm_ffi/cython/function.pxi @@ -131,7 +131,7 @@ cdef int TVMFFIPyArgSetterTensor_( TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* arg, TVMFFIAny* out ) except -1: - if (arg).chandle != NULL: + if (arg).chandle != NULL: out.type_index = kTVMFFITensor out.v_ptr = (arg).chandle else: @@ -144,8 +144,8 @@ cdef int TVMFFIPyArgSetterObject_( TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx, PyObject* arg, TVMFFIAny* out ) except -1: - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 @@ -312,7 +312,7 @@ cdef int TVMFFIPyArgSetterFFIObjectProtocol_( """Setter for objects that implement the `__tvm_ffi_object__` protocol.""" cdef object arg = py_arg cdef TVMFFIObjectHandle temp_chandle - cdef Object obj = arg.__tvm_ffi_object__() + cdef CObject obj = arg.__tvm_ffi_object__() cdef long ref_count = Py_REFCNT(obj) temp_chandle = obj.chandle out.type_index = TVMFFIObjectGetTypeIndex(temp_chandle) @@ -418,8 +418,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectStr_( # need to check if the arg is a large string returned from ffi if arg._tvm_ffi_cached_object is not None: arg = arg._tvm_ffi_cached_object - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out) @@ -457,8 +457,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectBytes_( # need to check if the arg is a large bytes returned from ffi if arg._tvm_ffi_cached_object is not None: arg = arg._tvm_ffi_cached_object - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out) @@ -473,8 +473,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_( raise ValueError(f"_tvm_ffi_cached_object is None for {type(arg)}") assert arg._tvm_ffi_cached_object is not None arg = arg._tvm_ffi_cached_object - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle return 0 @@ -507,7 +507,7 @@ cdef int TVMFFIPyArgSetterObjectRValueRef_( """Setter for ObjectRValueRef""" cdef object arg = py_arg out.type_index = kTVMFFIObjectRValueRef - out.v_ptr = &(((arg.obj)).chandle) + out.v_ptr = &(((arg.obj)).chandle) return 0 @@ -532,8 +532,8 @@ cdef int TVMFFIPyArgSetterException_( """Setter for Exception""" cdef object arg = py_arg arg = _convert_to_ffi_error(arg) - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle TVMFFIPyPushTempPyObject(ctx, arg) return 0 @@ -595,8 +595,8 @@ cdef int TVMFFIPyArgSetterObjectConvertible_( # recursively construct a new map cdef object arg = py_arg arg = arg.asobject() - out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) - out.v_ptr = (arg).chandle + out.type_index = TVMFFIObjectGetTypeIndex((arg).chandle) + out.v_ptr = (arg).chandle TVMFFIPyPushTempPyObject(ctx, arg) @@ -727,7 +727,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce if isinstance(arg, Tensor): out.func = TVMFFIPyArgSetterTensor_ return 0 - if isinstance(arg, Object): + if isinstance(arg, CObject): out.func = TVMFFIPyArgSetterObject_ return 0 if isinstance(arg, ObjectRValueRef): @@ -857,7 +857,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce # --------------------------------------------------------------------------------------------- # Implementation of function calling # --------------------------------------------------------------------------------------------- -cdef class Function(Object): +cdef class Function(CObject): """Callable wrapper around a TVM FFI function. Instances are obtained by converting Python callables with @@ -908,7 +908,7 @@ cdef class Function(Object): result.v_int64 = 0 TVMFFIPyFuncCall( TVMFFIPyArgSetterFactory_, - (self).chandle, args, + (self).chandle, args, &result, &c_api_ret_code, self.release_gil, @@ -972,7 +972,7 @@ cdef class Function(Object): CHECK_CALL(ret_code) func = Function.__new__(Function) - (func).chandle = chandle + (func).chandle = chandle return func @staticmethod @@ -1026,7 +1026,7 @@ cdef class Function(Object): TVMFFIPyMLIRPackedSafeCallDeleter(mlir_packed_safe_call) CHECK_CALL(ret_code) func = Function.__new__(Function) - (func).chandle = chandle + (func).chandle = chandle return func @@ -1039,7 +1039,7 @@ def _register_global_func(name: str, pyfunc: Callable[..., Any] | Function, over if not isinstance(pyfunc, Function): pyfunc = _convert_to_ffi_func(pyfunc) - CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (pyfunc).chandle, ioverride)) + CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (pyfunc).chandle, ioverride)) return pyfunc @@ -1050,7 +1050,7 @@ def _get_global_func(name: str, allow_missing: bool): CHECK_CALL(TVMFFIFunctionGetGlobal(name_arg.cptr(), &chandle)) if chandle != NULL: ret = Function.__new__(Function) - (ret).chandle = chandle + (ret).chandle = chandle return ret if allow_missing: @@ -1105,7 +1105,7 @@ def _convert_to_ffi_func(object pyfunc: Callable[..., Any]) -> Function: cdef TVMFFIObjectHandle chandle _convert_to_ffi_func_handle(pyfunc, &chandle) ret = Function.__new__(Function) - (ret).chandle = chandle + (ret).chandle = chandle return ret @@ -1127,7 +1127,7 @@ def _convert_to_opaque_object(object pyobject: Any) -> OpaquePyObject: cdef TVMFFIObjectHandle chandle _convert_to_opaque_object_handle(pyobject, &chandle) ret = OpaquePyObject.__new__(OpaquePyObject) - (ret).chandle = chandle + (ret).chandle = chandle return ret diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi index c7a9a47d..96cf7de3 100644 --- a/python/tvm_ffi/cython/object.pxi +++ b/python/tvm_ffi/cython/object.pxi @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import json +from abc import ABCMeta from typing import Any @@ -77,38 +78,13 @@ class ObjectRValueRef: self.obj = obj -cdef class Object: - """Base class of all TVM FFI objects. - - This is the root Python type for objects backed by the TVM FFI - runtime. Each instance references a handle to a C++ runtime - object. Python subclasses typically correspond to C++ runtime - types and are registered via :py:meth:`tvm_ffi.register_object`. - - Notes - ----- - - Equality of two :py:class:`Object` instances uses underlying handle - identity unless an overridden implementation is provided on the - concrete type. Use :py:meth:`same_as` to check whether two - references point to the same underlying object. - - Most users interact with subclasses (e.g. :class:`Tensor`, - :class:`Function`) rather than :py:class:`Object` directly. - - Examples - -------- - Constructing objects is typically performed by Python wrappers that - call into registered constructors on the FFI side. - - .. code-block:: python - - import tvm_ffi.testing - - # Acquire a testing object constructed through FFI - obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) - assert isinstance(obj, tvm_ffi.Object) - assert obj.same_as(obj) +cdef class CObject: + """Cython base class for TVM FFI objects. + This extension type owns the low-level handle. Prefer subclassing + :class:`Object` in Python to enforce slots policy. """ + __slots__ = () cdef void* chandle def __cinit__(self): @@ -163,30 +139,90 @@ cdef class Object: def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: - """Initialize the handle by calling constructor function. + def __hash__(self) -> int: + cdef uint64_t hash_value = self.chandle + return hash_value - Parameters - ---------- - fconstructor : Function - Constructor function. + def same_as(self, other: object) -> bool: + return isinstance(other, CObject) and self.chandle == (other).chandle - args: list of objects - The arguments to the constructor + def __move_handle_from__(self, other: CObject) -> None: + self.chandle = (other).chandle + (other).chandle = NULL - Notes - ----- - We have a special calling convention to call constructor functions. - So the return handle is directly set into the Node object - instead of creating a new Node. - """ + def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: # avoid error raised during construction. self.chandle = NULL cdef void* chandle ConstructorCall( - (fconstructor).chandle, args, &chandle, NULL) + (fconstructor).chandle, args, &chandle, NULL) self.chandle = chandle + +class _ObjectSlotsMeta(ABCMeta): + def __new__(mcls, name: str, bases: tuple[type, ...], ns: dict[str, Any], **kwargs: Any): + if "__slots__" not in ns: + ns["__slots__"] = () + return super().__new__(mcls, name, bases, ns, **kwargs) + + def __instancecheck__(cls, instance: Any) -> bool: + if isinstance(instance, CObject): + return True + return super().__instancecheck__(instance) + + def __subclasscheck__(cls, subclass: type) -> bool: + try: + if issubclass(subclass, CObject): + return True + except TypeError: + pass + return super().__subclasscheck__(subclass) + + +class Object(CObject, metaclass=_ObjectSlotsMeta): + """Base class of all TVM FFI objects. + + This is the root Python type for objects backed by the TVM FFI + runtime. Each instance references a handle to a C++ runtime + object. Python subclasses typically correspond to C++ runtime + types and are registered via :py:meth:`tvm_ffi.register_object`. + + Notes + ----- + - Equality of two :py:class:`Object` instances uses underlying handle + identity unless an overridden implementation is provided on the + concrete type. Use :py:meth:`same_as` to check whether two + references point to the same underlying object. + - Subclasses that omit ``__slots__`` are treated as ``__slots__ = ()``. + ``__dict__`` entries are forbidden in ``__slots__``. + - Most users interact with subclasses (e.g. :class:`Tensor`, + :class:`Function`) rather than :py:class:`Object` directly. + + Examples + -------- + Constructing objects is typically performed by Python wrappers that + call into registered constructors on the FFI side. + + .. code-block:: python + + import tvm_ffi.testing + + # Acquire a testing object constructed through FFI + obj = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12) + assert isinstance(obj, tvm_ffi.Object) + assert obj.same_as(obj) + + Subclasses can declare explicit slots when needed. + + .. code-block:: python + + @tvm_ffi.register_object("my.MyObject") + class MyObject(tvm_ffi.Object): + __slots__ = () + + """ + __slots__ = () + def __ffi_init__(self, *args: Any) -> None: """Initialize the instance using the ``__ffi_init__`` method registered on C++ side. @@ -225,13 +261,7 @@ cdef class Object: assert not x.same_as(z) """ - if not isinstance(other, Object): - return False - return self.chandle == (other).chandle - - def __hash__(self) -> int: - cdef uint64_t hash_value = self.chandle - return hash_value + return CObject.same_as(self, other) def _move(self) -> ObjectRValueRef: """Create an rvalue reference that transfers ownership. @@ -253,17 +283,35 @@ cdef class Object: """ return ObjectRValueRef(self) - def __move_handle_from__(self, other: Object) -> None: + def __move_handle_from__(self, other: CObject) -> None: """Steal the FFI handle from ``other``. Internal helper used by the runtime to implement move semantics. Users should prefer :py:meth:`_move`. """ - self.chandle = (other).chandle - (other).chandle = NULL + CObject.__move_handle_from__(self, other) + + def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: + """Initialize the handle by calling constructor function. + + Parameters + ---------- + fconstructor : Function + Constructor function. + + args: list of objects + The arguments to the constructor + + Notes + ----- + We have a special calling convention to call constructor functions. + So the return handle is directly set into the Node object + instead of creating a new Node. + """ + CObject.__init_handle_by_constructor__(self, fconstructor, *args) -cdef class OpaquePyObject(Object): +cdef class OpaquePyObject(CObject): """Wrapper that carries an arbitrary Python object across the FFI. The contained object is held with correct reference counting, and @@ -274,6 +322,8 @@ cdef class OpaquePyObject(Object): ``OpaquePyObject`` is useful when a Python value must traverse the FFI boundary without conversion into a native FFI type. """ + __slots__ = () + def pyobject(self) -> object: """Return the original Python object held by this wrapper.""" cdef object obj @@ -335,7 +385,7 @@ cdef inline str _type_index_to_key(int32_t tindex): cdef inline object make_ret_opaque_object(TVMFFIAny result): obj = OpaquePyObject.__new__(OpaquePyObject) - (obj).chandle = result.v_obj + (obj).chandle = result.v_obj return obj.pyobject() cdef inline object make_fallback_cls_for_type_index(int32_t type_index): @@ -352,30 +402,25 @@ cdef inline object make_fallback_cls_for_type_index(int32_t type_index): # Create `type_info.type_cls` now class cls(parent_type_info.type_cls): - pass - attrs = cls.__dict__.copy() - attrs.pop("__dict__", None) - attrs.pop("__weakref__", None) - attrs.update({ - "__slots__": (), - "__tvm_ffi_type_info__": type_info, - "__name__": type_key.split(".")[-1], - "__qualname__": type_key, - "__module__": ".".join(type_key.split(".")[:-1]), - "__doc__": f"Auto-generated fallback class for {type_key}.\n" - "This class is generated because the class is not registered.\n" - "Please do not use this class directly, instead register the class\n" - "using `register_object` decorator.", - }) + __slots__ = () + + cls.__tvm_ffi_type_info__ = type_info + cls.__name__ = type_key.split(".")[-1] + cls.__qualname__ = type_key + cls.__module__ = ".".join(type_key.split(".")[:-1]) + cls.__doc__ = ( + f"Auto-generated fallback class for {type_key}.\n" + "This class is generated because the class is not registered.\n" + "Please do not use this class directly, instead register the class\n" + "using `register_object` decorator." + ) for field in type_info.fields: - attrs[field.name] = field.as_property(cls) + setattr(cls, field.name, field.as_property(cls)) for method in type_info.methods: name = method.name if name == "__ffi_init__": name = "__c_ffi_init__" - attrs[name] = method.as_callable(cls) - for name, val in attrs.items(): - setattr(cls, name, val) + setattr(cls, name, method.as_callable(cls)) # Update the registry type_info.type_cls = cls _update_registry(type_index, type_key, type_info, cls) @@ -390,7 +435,7 @@ cdef inline object make_ret_object(TVMFFIAny result): if type_index < len(TYPE_INDEX_TO_CLS) and (cls := TYPE_INDEX_TO_CLS[type_index]) is not None: if issubclass(cls, PyNativeObject): obj = Object.__new__(Object) - (obj).chandle = result.v_obj + (obj).chandle = result.v_obj return cls.__from_tvm_ffi_object__(cls, obj) else: # Slow path: object is not found in registered entry @@ -398,7 +443,7 @@ cdef inline object make_ret_object(TVMFFIAny result): # For every unregistered class, this slow path will be triggered only once. cls = make_fallback_cls_for_type_index(type_index) obj = cls.__new__(cls) - (obj).chandle = result.v_obj + (obj).chandle = result.v_obj return obj diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi index 2c23a08f..399c8199 100644 --- a/python/tvm_ffi/cython/string.pxi +++ b/python/tvm_ffi/cython/string.pxi @@ -19,12 +19,12 @@ # helper class for string/bytes handling cdef inline str _string_obj_get_py_str(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) + cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) return bytearray_to_str(bytes) cdef inline bytes _bytes_obj_get_py_bytes(obj): - cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) + cdef TVMFFIByteArray* bytes = TVMFFIBytesGetByteArrayPtr((obj).chandle) return bytearray_to_bytes(bytes) diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi index 8b78c809..dcf8e19c 100644 --- a/python/tvm_ffi/cython/tensor.pxi +++ b/python/tvm_ffi/cython/tensor.pxi @@ -230,8 +230,8 @@ def from_dlpack( # helper class for shape handling -def _shape_obj_get_py_tuple(obj: "Object") -> tuple[int, ...]: - cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((obj).chandle) +def _shape_obj_get_py_tuple(obj: "CObject") -> tuple[int, ...]: + cdef TVMFFIShapeCell* shape = TVMFFIShapeGetCellPtr((obj).chandle) return tuple(shape.data[i] for i in range(shape.size)) @@ -247,7 +247,7 @@ def _make_strides_from_shape(tuple shape: tuple[int, ...]) -> tuple[int, ...]: return tuple(reversed(strides)) -cdef class Tensor(Object): +cdef class Tensor(CObject): """Managed n-dimensional array compatible with DLPack. It provides zero-copy interoperability with array libraries @@ -268,6 +268,7 @@ cdef class Tensor(Object): np.testing.assert_equal(np.from_dlpack(x), np.arange(6, dtype="int32")) """ + __slots__ = () cdef DLTensor* cdltensor @property @@ -433,6 +434,7 @@ cdef DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept: cdef class DLTensorTestWrapper: """Wrapper of a Tensor that exposes DLPack protocol, only for testing purpose. """ + __slots__ = () __dlpack_c_exchange_api__ = pycapsule.PyCapsule_New( _dltensor_test_wrapper_get_exchange_api(), b"dlpack_exchange_api", @@ -465,7 +467,7 @@ cdef inline object make_ret_dltensor(TVMFFIAny result): cdef DLTensor* dltensor dltensor = result.v_ptr tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) - (tensor).chandle = NULL + (tensor).chandle = NULL (tensor).cdltensor = dltensor return tensor @@ -497,7 +499,7 @@ cdef inline object make_tensor_from_chandle( pass # default return the tensor tensor = _CLASS_TENSOR.__new__(_CLASS_TENSOR) - (tensor).chandle = chandle + (tensor).chandle = chandle (tensor).cdltensor = TVMFFITensorGetDLTensorPtr(chandle) return tensor diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi index 093d8f72..c15ff1c9 100644 --- a/python/tvm_ffi/cython/type_info.pxi +++ b/python/tvm_ffi/cython/type_info.pxi @@ -25,10 +25,10 @@ cdef class FieldGetter: cdef TVMFFIFieldGetter getter cdef int64_t offset - def __call__(self, Object obj): + def __call__(self, CObject obj): cdef TVMFFIAny result cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset + cdef void* field_ptr = ((obj).chandle) + self.offset result.type_index = kTVMFFINone result.v_int64 = 0 c_api_ret_code = self.getter(field_ptr, &result) @@ -41,9 +41,9 @@ cdef class FieldSetter: cdef TVMFFIFieldSetter setter cdef int64_t offset - def __call__(self, Object obj, value): + def __call__(self, CObject obj, value): cdef int c_api_ret_code - cdef void* field_ptr = ((obj).chandle) + self.offset + cdef void* field_ptr = ((obj).chandle) + self.offset TVMFFIPyCallFieldSetter( TVMFFIPyArgSetterFactory_, self.setter, diff --git a/python/tvm_ffi/module.py b/python/tvm_ffi/module.py index 5aac8f09..3e46f9f2 100644 --- a/python/tvm_ffi/module.py +++ b/python/tvm_ffi/module.py @@ -113,6 +113,7 @@ def run_some_tests(): # tvm-ffi-stubgen(end) entry_name: ClassVar[str] = "main" # constant for entry function name + __slots__ = ("_tvm_ffi_attr_cache",) @property def kind(self) -> str: @@ -161,11 +162,18 @@ def implements_function(self, name: str, query_imports: bool = False) -> bool: def __getattr__(self, name: str) -> core.Function: """Accessor to allow getting functions as attributes.""" try: - func = self.get_function(name) - self.__dict__[name] = func - return func + cache = object.__getattribute__(self, "_tvm_ffi_attr_cache") except AttributeError: - raise AttributeError(f"Module has no function '{name}'") + cache = {} + object.__setattr__(self, "_tvm_ffi_attr_cache", cache) + if name in cache: + return cache[name] + try: + func = self.get_function(name) + except AttributeError as exc: + raise AttributeError(f"Module has no function '{name}'") from exc + cache[name] = func + return func def get_function(self, name: str, query_imports: bool = False) -> core.Function: """Get function from the module. diff --git a/tests/python/test_object.py b/tests/python/test_object.py index e49d3ec5..8d1e64e5 100644 --- a/tests/python/test_object.py +++ b/tests/python/test_object.py @@ -174,6 +174,38 @@ def _check_type(x: Any) -> None: _check_type(obj) +@pytest.mark.parametrize( + ("test_cls", "make_instance"), + [ + ( + tvm_ffi.testing.TestObjectBase, + lambda: tvm_ffi.testing.create_object("testing.TestObjectBase"), + ), + ( + tvm_ffi.testing.TestIntPair, + lambda: tvm_ffi.testing.TestIntPair(1, 2), # type: ignore[call-arg] + ), + ( + tvm_ffi.testing.TestObjectDerived, + lambda: tvm_ffi.testing.create_object( + "testing.TestObjectDerived", + v_i64=20, + v_map=tvm_ffi.convert({"a": 1}), + v_array=tvm_ffi.convert([1, 2]), + ), + ), + ], +) +def test_object_subclass_slots(test_cls: type, make_instance: Any) -> None: + slots = test_cls.__dict__.get("__slots__") + assert slots == () + assert "__dict__" not in test_cls.__dict__ + assert "__weakref__" not in test_cls.__dict__ + obj = make_instance() + with pytest.raises(AttributeError): + obj._tvm_ffi_test_attr = "nope" + + @pytest.mark.parametrize( "test_cls, type_key, parent_cls", [