Skip to content

Commit 33f1acb

Browse files
Enable async state/message handlers (asyncio/trio/curio) with ContextVar preservation and tests
1 parent 02562fe commit 33f1acb

File tree

3 files changed

+1162
-8
lines changed

3 files changed

+1162
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ test = [
120120
"pytest-rerunfailures~=16.0",
121121
"pytest~=9.0",
122122
"requests~=2.0",
123+
"sniffio>=1.3",
123124
"testcontainers~=4.0",
124125
]
125126
types = [

src/pact/_util.py

Lines changed: 187 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,37 @@
99

1010
from __future__ import annotations
1111

12+
import asyncio
13+
import contextvars
14+
import dis
1215
import inspect
1316
import logging
1417
import socket
1518
import warnings
1619
from contextlib import closing
1720
from functools import partial
1821
from inspect import Parameter, _ParameterKind
19-
from typing import TYPE_CHECKING, TypeVar
22+
from typing import TYPE_CHECKING, Any, TypeVar
2023

2124
if TYPE_CHECKING:
22-
from collections.abc import Callable, Mapping
25+
from collections.abc import Callable, Coroutine, Mapping
26+
27+
try:
28+
import sniffio # type: ignore[import-not-found]
29+
except ImportError:
30+
sniffio = None # type: ignore[assignment]
31+
32+
try:
33+
import trio # type: ignore[import-not-found]
34+
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
35+
except ImportError:
36+
trio = None # type: ignore[assignment]
37+
current_trio_token = None # type: ignore[assignment]
38+
39+
try:
40+
import curio # type: ignore[import-not-found,import-untyped]
41+
except (ImportError, AttributeError):
42+
curio = None # type: ignore[assignment]
2343

2444
logger = logging.getLogger(__name__)
2545

@@ -179,7 +199,7 @@ def find_free_port() -> int:
179199
return s.getsockname()[1]
180200

181201

182-
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
202+
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
183203
"""
184204
Apply arguments to a function.
185205
@@ -188,6 +208,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188208
it is possible to pass arguments by name, and falling back to positional
189209
arguments if not.
190210
211+
This function supports both synchronous and asynchronous callables. If the
212+
callable is a coroutine function, it will be executed in an event loop.
213+
191214
Args:
192215
f:
193216
The function to apply the arguments to.
@@ -200,6 +223,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200223
Returns:
201224
The result of the function.
202225
"""
226+
# Check if f is a partial wrapping an async function
227+
func_to_check = f.func if isinstance(f, partial) else f
228+
is_async = inspect.iscoroutinefunction(func_to_check)
203229
signature = inspect.signature(f)
204230
f_name = (
205231
f.__qualname__
@@ -226,7 +252,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226252
# First, we inspect the keyword arguments and try and pass in some arguments
227253
# by currying them in.
228254
for param in signature.parameters.values():
229-
if param.name not in args:
255+
# Try matching the parameter name, or if it starts with underscore,
256+
# also try matching without the leading underscore.
257+
arg_key = None
258+
if param.name in args:
259+
arg_key = param.name
260+
elif (
261+
param.name.startswith("_")
262+
and len(param.name) > 1
263+
and param.name[1:] in args
264+
):
265+
arg_key = param.name[1:]
266+
267+
if arg_key is None:
230268
# If a parameter is not known, we will ignore it.
231269
#
232270
# If the ignored parameter doesn't have a default value, it will
@@ -246,12 +284,13 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246284
if param.kind in positional_match:
247285
# We iterate through the parameters in order that they are defined,
248286
# making it fine to pass them in by position one at a time.
249-
f = partial(f, args[param.name])
250-
del args[param.name]
287+
f = partial(f, args[arg_key])
288+
del args[arg_key]
289+
continue
251290

252291
if param.kind in keyword_match:
253-
f = partial(f, **{param.name: args[param.name]})
254-
del args[param.name]
292+
f = partial(f, **{param.name: args[arg_key]})
293+
del args[arg_key]
255294
continue
256295

257296
# At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +321,147 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282321
)
283322

284323
try:
324+
if is_async:
325+
result = f()
326+
if inspect.iscoroutine(result):
327+
return _run_async_coroutine(result)
328+
return result
285329
return f()
286330
except Exception:
287331
logger.exception("Error occurred while calling function %s", f_name)
288332
raise
333+
334+
335+
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
336+
"""
337+
Run a coroutine in an event loop.
338+
339+
Detects the async runtime (asyncio, trio, or curio) and executes the
340+
coroutine appropriately. Preserves ContextVars when creating a new event
341+
loop, which is important when handlers are called from threads.
342+
343+
Args:
344+
coro:
345+
The coroutine to run.
346+
347+
Returns:
348+
The result of the coroutine.
349+
350+
Raises:
351+
RuntimeError:
352+
If the detected runtime (trio or curio) is not installed.
353+
"""
354+
runtime = _detect_async_runtime_from_coroutine(coro)
355+
356+
if runtime == "trio":
357+
if trio is None:
358+
msg = "trio is not installed"
359+
raise RuntimeError(msg)
360+
361+
if current_trio_token is not None:
362+
try:
363+
token = current_trio_token()
364+
365+
async def _run_with_token() -> _T:
366+
return await coro
367+
368+
return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
369+
except RuntimeError:
370+
pass
371+
372+
ctx = contextvars.copy_context()
373+
374+
async def _run_trio() -> _T:
375+
return await coro
376+
377+
return ctx.run(trio.run, _run_trio)
378+
379+
if runtime == "curio":
380+
if curio is None:
381+
msg = "curio is not installed"
382+
raise RuntimeError(msg)
383+
384+
try:
385+
return curio.AWAIT(coro)
386+
except RuntimeError:
387+
ctx = contextvars.copy_context()
388+
389+
async def _run_curio() -> _T:
390+
return await coro
391+
392+
return ctx.run(curio.run, _run_curio)
393+
394+
try:
395+
loop = asyncio.get_running_loop()
396+
except RuntimeError:
397+
loop = None
398+
399+
if loop is not None:
400+
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
401+
return future.result()
402+
403+
ctx = contextvars.copy_context()
404+
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]
405+
406+
407+
def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str: # noqa: C901
408+
"""
409+
Detect async runtime by inspecting the coroutine object.
410+
411+
Args:
412+
coro:
413+
The coroutine object to inspect.
414+
415+
Returns:
416+
The detected runtime: "asyncio", "trio", or "curio".
417+
"""
418+
if sniffio is not None:
419+
try:
420+
return sniffio.current_async_library()
421+
except sniffio.AsyncLibraryNotFoundError:
422+
pass
423+
424+
# Inspect bytecode to check for qualified attribute access (e.g., trio.sleep)
425+
# This is more robust than just checking co_names for module and method separately
426+
func_code = coro.cr_code # type: ignore[attr-defined]
427+
428+
# Parse bytecode to find LOAD_GLOBAL/LOAD_NAME followed by LOAD_ATTR patterns
429+
# This detects qualified accesses like `trio.sleep()` or `curio.spawn()`
430+
bytecode = list(dis.get_instructions(func_code))
431+
432+
trio_detected = False
433+
curio_detected = False
434+
435+
for i, instr in enumerate(bytecode):
436+
# Check for module.attribute pattern (LOAD_GLOBAL/LOAD_NAME + LOAD_ATTR)
437+
if instr.opname in ("LOAD_GLOBAL", "LOAD_NAME") and i + 1 < len(bytecode):
438+
next_instr = bytecode[i + 1]
439+
if next_instr.opname == "LOAD_ATTR":
440+
module_name = instr.argval
441+
attr_name = next_instr.argval
442+
443+
# Check for trio-specific qualified access
444+
if module_name == "trio":
445+
trio_indicators = {
446+
"sleep",
447+
"open_nursery",
448+
"CancelScope",
449+
"current_trio_token",
450+
}
451+
if attr_name in trio_indicators:
452+
trio_detected = True
453+
454+
# Check for curio-specific qualified access
455+
elif module_name == "curio":
456+
curio_indicators = {"sleep", "spawn", "TaskGroup", "AWAIT"}
457+
if attr_name in curio_indicators:
458+
curio_detected = True
459+
460+
# Trio takes precedence if both are detected
461+
if trio_detected:
462+
return "trio"
463+
if curio_detected:
464+
return "curio"
465+
466+
# Default to asyncio as it's the most common
467+
return "asyncio"

0 commit comments

Comments
 (0)