Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions mypyc/lib-rt/function_wrapper.c
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ static PyObject* CPyFunction_Vectorcall(PyObject *func, PyObject *const *args, s
}


// Steals ml, name, and code. Borrows module.
static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject* name,
PyObject *module, PyObject* code, bool set_self) {
PyCFunctionObject *cf = (PyCFunctionObject *)op;
Expand All @@ -206,12 +207,10 @@ static CPyFunction* CPyFunction_Init(CPyFunction *op, PyMethodDef *ml, PyObject*
Py_XINCREF(module);
cf->m_module = module;

Py_INCREF(name);
op->func_name = name;

((PyCMethodObject *)op)->mm_class = NULL;

Py_XINCREF(code);
op->func_code = code;

CPyFunction_func_vectorcall(op) = CPyFunction_Vectorcall;
Expand Down Expand Up @@ -243,7 +242,7 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu
PyCFunction func, int func_flags, const char *func_doc,
int first_line, int code_flags, bool has_self_arg) {
PyMethodDef *method = NULL;
PyObject *code = NULL, *op = NULL;
PyObject *code = NULL, *name = NULL, *op = NULL;
bool set_self = false;

#ifdef Py_GIL_DISABLED
Expand Down Expand Up @@ -280,18 +279,24 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu
if (unlikely(!code)) {
goto err;
}
name = PyUnicode_FromString(funcname);
if (unlikely(!name)) {
goto err;
}

// Set m_self inside the function wrapper only if the wrapped function has no self arg
// to pass m_self as the self arg when the function is called.
// When the function has a self arg, it will come in the args vector passed to the
// vectorcall handler.
set_self = !has_self_arg;
op = (PyObject *)CPyFunction_Init(PyObject_GC_New(CPyFunction, CPyFunctionType),
method, PyUnicode_FromString(funcname), module,
code, set_self);
if (unlikely(!op)) {
CPyFunction *raw = PyObject_GC_New(CPyFunction, CPyFunctionType);
if (unlikely(!raw)) {
goto err;
}
op = (PyObject *)CPyFunction_Init(raw, method, name, module, code, set_self);
method = NULL;
name = NULL;
code = NULL;
PyObject_GC_Track(op);
return op;

Expand All @@ -300,5 +305,7 @@ PyObject* CPyFunction_New(PyObject *module, const char *filename, const char *fu
if (method) {
PyMem_Free(method);
}
Py_XDECREF(name);
Py_XDECREF(code);
return NULL;
}
34 changes: 34 additions & 0 deletions mypyc/test-data/run-async.test
Original file line number Diff line number Diff line change
Expand Up @@ -1416,13 +1416,16 @@ def run(x: object) -> object: ...

[case testAsyncIntrospection]
import asyncio
import gc
import inspect
import sys
import weakref

from functools import wraps
from typing import Any, Callable, TypeVar, cast

from testutil import is_gil_disabled

def identity(val: int) -> int:
return val

Expand Down Expand Up @@ -1589,6 +1592,37 @@ def test_nested() -> None:
assert is_coroutine(nested_wrapped_async)
assert asyncio.run(nested_wrapped_async()) == 4

def test_async_function_wrapper_code_refcount() -> None:
if is_gil_disabled():
# On free-threaded builds the code object might be immortal, so the ref count test doesn't work.
return
code = getattr(identity_async, "__code__")
getrefcount = getattr(sys, "getrefcount")
# getrefcount sees the local code variable plus the wrapper-owned reference.
assert getrefcount(code) == 2, getrefcount(code)

def test_nested_async_function_wrapper_code_refcount() -> None:
if is_gil_disabled():
# On free-threaded builds the code object might be immortal, so the ref count test doesn't work.
return
def make_nested() -> Any:
async def nested_refcounted() -> int:
return 1

return nested_refcounted

getrefcount = getattr(sys, "getrefcount")
fn = make_nested()
code = getattr(fn, "__code__")
before = getrefcount(code)
assert asyncio.run(fn()) == 1

del fn
gc.collect()
after = getrefcount(code)
assert before == after + 1, (before, after)
assert after == 1, after

[file asyncio/__init__.pyi]
def run(x: object) -> object: ...

Expand Down
Loading