Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ test = [
"pytest-rerunfailures~=16.0",
"pytest~=9.0",
"requests~=2.0",
"sniffio>=1.3",
"testcontainers~=4.0",
]
types = [
Expand Down
195 changes: 187 additions & 8 deletions src/pact/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,37 @@

from __future__ import annotations

import asyncio
import contextvars
import dis
import inspect
import logging
import socket
import warnings
from contextlib import closing
from functools import partial
from inspect import Parameter, _ParameterKind
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from collections.abc import Callable, Coroutine, Mapping

try:
import sniffio # type: ignore[import-not-found]
except ImportError:
sniffio = None # type: ignore[assignment]

try:
import trio # type: ignore[import-not-found]
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
except ImportError:
trio = None # type: ignore[assignment]
current_trio_token = None # type: ignore[assignment]

try:
import curio # type: ignore[import-not-found,import-untyped]
except (ImportError, AttributeError):
curio = None # type: ignore[assignment]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -179,7 +199,7 @@ def find_free_port() -> int:
return s.getsockname()[1]


def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
"""
Apply arguments to a function.
Expand All @@ -188,6 +208,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
it is possible to pass arguments by name, and falling back to positional
arguments if not.
This function supports both synchronous and asynchronous callables. If the
callable is a coroutine function, it will be executed in an event loop.
Args:
f:
The function to apply the arguments to.
Expand All @@ -200,6 +223,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
Returns:
The result of the function.
"""
# Check if f is a partial wrapping an async function
func_to_check = f.func if isinstance(f, partial) else f
is_async = inspect.iscoroutinefunction(func_to_check)
signature = inspect.signature(f)
f_name = (
f.__qualname__
Expand All @@ -226,7 +252,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
# First, we inspect the keyword arguments and try and pass in some arguments
# by currying them in.
for param in signature.parameters.values():
if param.name not in args:
# Try matching the parameter name, or if it starts with underscore,
# also try matching without the leading underscore.
arg_key = None
if param.name in args:
arg_key = param.name
elif (
param.name.startswith("_")
and len(param.name) > 1
and param.name[1:] in args
):
arg_key = param.name[1:]

if arg_key is None:
# If a parameter is not known, we will ignore it.
#
# If the ignored parameter doesn't have a default value, it will
Expand All @@ -246,12 +284,13 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
if param.kind in positional_match:
# We iterate through the parameters in order that they are defined,
# making it fine to pass them in by position one at a time.
f = partial(f, args[param.name])
del args[param.name]
f = partial(f, args[arg_key])
del args[arg_key]
continue

if param.kind in keyword_match:
f = partial(f, **{param.name: args[param.name]})
del args[param.name]
f = partial(f, **{param.name: args[arg_key]})
del args[arg_key]
continue

# At this stage, we have checked all arguments. If we have any arguments
Expand Down Expand Up @@ -282,7 +321,147 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
)

try:
if is_async:
result = f()
if inspect.iscoroutine(result):
return _run_async_coroutine(result)
return result
return f()
except Exception:
logger.exception("Error occurred while calling function %s", f_name)
raise


def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
"""
Run a coroutine in an event loop.
Detects the async runtime (asyncio, trio, or curio) and executes the
coroutine appropriately. Preserves ContextVars when creating a new event
loop, which is important when handlers are called from threads.
Args:
coro:
The coroutine to run.
Returns:
The result of the coroutine.
Raises:
RuntimeError:
If the detected runtime (trio or curio) is not installed.
"""
runtime = _detect_async_runtime_from_coroutine(coro)

if runtime == "trio":
if trio is None:
msg = "trio is not installed"
raise RuntimeError(msg)

if current_trio_token is not None:
try:
token = current_trio_token()

async def _run_with_token() -> _T:
return await coro

return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
except RuntimeError:
pass

ctx = contextvars.copy_context()

async def _run_trio() -> _T:
return await coro

return ctx.run(trio.run, _run_trio)

if runtime == "curio":
if curio is None:
msg = "curio is not installed"
raise RuntimeError(msg)

try:
return curio.AWAIT(coro)
except RuntimeError:
ctx = contextvars.copy_context()

async def _run_curio() -> _T:
return await coro

return ctx.run(curio.run, _run_curio)

try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop is not None:
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
return future.result()

ctx = contextvars.copy_context()
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]


def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str: # noqa: C901
"""
Detect async runtime by inspecting the coroutine object.
Args:
coro:
The coroutine object to inspect.
Returns:
The detected runtime: "asyncio", "trio", or "curio".
"""
if sniffio is not None:
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass

# Inspect bytecode to check for qualified attribute access (e.g., trio.sleep)
# This is more robust than just checking co_names for module and method separately
func_code = coro.cr_code # type: ignore[attr-defined]

# Parse bytecode to find LOAD_GLOBAL/LOAD_NAME followed by LOAD_ATTR patterns
# This detects qualified accesses like `trio.sleep()` or `curio.spawn()`
bytecode = list(dis.get_instructions(func_code))

trio_detected = False
curio_detected = False

for i, instr in enumerate(bytecode):
# Check for module.attribute pattern (LOAD_GLOBAL/LOAD_NAME + LOAD_ATTR)
if instr.opname in ("LOAD_GLOBAL", "LOAD_NAME") and i + 1 < len(bytecode):
next_instr = bytecode[i + 1]
if next_instr.opname == "LOAD_ATTR":
module_name = instr.argval
attr_name = next_instr.argval

# Check for trio-specific qualified access
if module_name == "trio":
trio_indicators = {
"sleep",
"open_nursery",
"CancelScope",
"current_trio_token",
}
if attr_name in trio_indicators:
trio_detected = True

# Check for curio-specific qualified access
elif module_name == "curio":
curio_indicators = {"sleep", "spawn", "TaskGroup", "AWAIT"}
if attr_name in curio_indicators:
curio_detected = True

# Trio takes precedence if both are detected
if trio_detected:
return "trio"
if curio_detected:
return "curio"

# Default to asyncio as it's the most common
return "asyncio"
Loading
Loading