diff --git a/mypyc/lib-rt/function_wrapper.c b/mypyc/lib-rt/function_wrapper.c index 16b1b5930348..dfea6aed1d1c 100644 --- a/mypyc/lib-rt/function_wrapper.c +++ b/mypyc/lib-rt/function_wrapper.c @@ -196,6 +196,7 @@ static PyObject* CPyFunction_Vectorcall(PyObject *func, PyObject *const *args, s } +// Steals ml, name, and code. Borrows module. static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* name, PyObject *module, PyObject* code, bool set_self) { PyCFunctionObject *cf = (PyCFunctionObject *)op; @@ -206,12 +207,10 @@ static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* Py_XINCREF(module); cf->m_module = module; - Py_INCREF(name); op->func_name = name; ((PyCMethodObject *)op)->mm_class = NULL; - Py_XINCREF(code); op->func_code = code; CPyFunction_func_vectorcall(op) = CPyFunction_Vectorcall; @@ -243,7 +242,7 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu PyCFunction func, int func_flags, const char *func_doc, int first_line, int code_flags, bool has_self_arg) { PyMethodDef *method = NULL; - PyObject *code = NULL, *op = NULL; + PyObject *code = NULL, *name = NULL, *op = NULL; bool set_self = false; #ifdef Py_GIL_DISABLED @@ -280,18 +279,24 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu if (unlikely(!code)) { goto err; } + name = PyUnicode_FromString(funcname); + if (unlikely(!name)) { + goto err; + } // Set m_self inside the function wrapper only if the wrapped function has no self arg // to pass m_self as the self arg when the function is called. // When the function has a self arg, it will come in the args vector passed to the // vectorcall handler. set_self = !has_self_arg; - op = (PyObject *)CPyFunction_Init(PyObject_GC_New(CPyFunction, CPyFunctionType), - method, PyUnicode_FromString(funcname), module, - code, set_self); - if (unlikely(!op)) { + CPyFunction *raw = PyObject_GC_New(CPyFunction, CPyFunctionType); + if (unlikely(!raw)) { goto err; } + op = (PyObject *)CPyFunction_Init(raw, method, name, module, code, set_self); + method = NULL; + name = NULL; + code = NULL; PyObject_GC_Track(op); return op; @@ -300,5 +305,7 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu if (method) { PyMem_Free(method); } + Py_XDECREF(name); + Py_XDECREF(code); return NULL; } diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 2733a31f3af2..cf9b54368c27 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -1416,6 +1416,7 @@ def run(x: object) -> object: ... [case testAsyncIntrospection] import asyncio +import gc import inspect import sys import weakref @@ -1423,6 +1424,8 @@ import weakref from functools import wraps from typing import Any, Callable, TypeVar, cast +from testutil import is_gil_disabled + def identity(val: int) -> int: return val @@ -1589,6 +1592,37 @@ def test_nested() -> None: assert is_coroutine(nested_wrapped_async) assert asyncio.run(nested_wrapped_async()) == 4 +def test_async_function_wrapper_code_refcount() -> None: + if is_gil_disabled(): + # On free-threaded builds the code object might be immortal, so the ref count test doesn't work. + return + code = getattr(identity_async, "__code__") + getrefcount = getattr(sys, "getrefcount") + # getrefcount sees the local code variable plus the wrapper-owned reference. + assert getrefcount(code) == 2, getrefcount(code) + +def test_nested_async_function_wrapper_code_refcount() -> None: + if is_gil_disabled(): + # On free-threaded builds the code object might be immortal, so the ref count test doesn't work. + return + def make_nested() -> Any: + async def nested_refcounted() -> int: + return 1 + + return nested_refcounted + + getrefcount = getattr(sys, "getrefcount") + fn = make_nested() + code = getattr(fn, "__code__") + before = getrefcount(code) + assert asyncio.run(fn()) == 1 + + del fn + gc.collect() + after = getrefcount(code) + assert before == after + 1, (before, after) + assert after == 1, after + [file asyncio/__init__.pyi] def run(x: object) -> object: ...