Skip to content

Commit 3ccc083

Browse files
Enable async state/message handlers (asyncio/trio/curio) with ContextVar preservation and tests
1 parent 02562fe commit 3ccc083

File tree

2 files changed

+723
-8
lines changed

2 files changed

+723
-8
lines changed

src/pact/_util.py

Lines changed: 166 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,39 @@
99

1010
from __future__ import annotations
1111

12+
import asyncio
13+
import contextvars
1214
import inspect
1315
import logging
1416
import socket
1517
import warnings
1618
from contextlib import closing
1719
from functools import partial
1820
from inspect import Parameter, _ParameterKind
19-
from typing import TYPE_CHECKING, TypeVar
21+
from typing import TYPE_CHECKING, Any, TypeVar
2022

2123
if TYPE_CHECKING:
22-
from collections.abc import Callable, Mapping
24+
from collections.abc import Callable, Coroutine, Mapping
25+
26+
try:
27+
import sniffio # type: ignore[import-not-found]
28+
except ImportError:
29+
sniffio = None # type: ignore[assignment]
30+
31+
try:
32+
import trio # type: ignore[import-not-found]
33+
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
34+
except ImportError:
35+
trio = None # type: ignore[assignment]
36+
current_trio_token = None # type: ignore[assignment]
37+
38+
try:
39+
import curio # type: ignore[import-not-found,import-untyped]
40+
except (ImportError, AttributeError):
41+
curio = None # type: ignore[assignment]
42+
43+
if TYPE_CHECKING:
44+
from collections.abc import Callable, Coroutine, Mapping
2345

2446
logger = logging.getLogger(__name__)
2547

@@ -179,7 +201,7 @@ def find_free_port() -> int:
179201
return s.getsockname()[1]
180202

181203

182-
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
204+
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
183205
"""
184206
Apply arguments to a function.
185207
@@ -188,6 +210,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188210
it is possible to pass arguments by name, and falling back to positional
189211
arguments if not.
190212
213+
This function supports both synchronous and asynchronous callables. If the
214+
callable is a coroutine function, it will be executed in an event loop.
215+
191216
Args:
192217
f:
193218
The function to apply the arguments to.
@@ -200,6 +225,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200225
Returns:
201226
The result of the function.
202227
"""
228+
# Check if f is a partial wrapping an async function
229+
func_to_check = f.func if isinstance(f, partial) else f
230+
is_async = inspect.iscoroutinefunction(func_to_check)
203231
signature = inspect.signature(f)
204232
f_name = (
205233
f.__qualname__
@@ -226,7 +254,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226254
# First, we inspect the keyword arguments and try and pass in some arguments
227255
# by currying them in.
228256
for param in signature.parameters.values():
229-
if param.name not in args:
257+
# Try matching the parameter name, or if it starts with underscore,
258+
# also try matching without the leading underscore.
259+
arg_key = None
260+
if param.name in args:
261+
arg_key = param.name
262+
elif (
263+
param.name.startswith("_")
264+
and len(param.name) > 1
265+
and param.name[1:] in args
266+
):
267+
arg_key = param.name[1:]
268+
269+
if arg_key is None:
230270
# If a parameter is not known, we will ignore it.
231271
#
232272
# If the ignored parameter doesn't have a default value, it will
@@ -246,12 +286,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246286
if param.kind in positional_match:
247287
# We iterate through the parameters in order that they are defined,
248288
# making it fine to pass them in by position one at a time.
249-
f = partial(f, args[param.name])
250-
del args[param.name]
289+
f = partial(f, args[arg_key])
290+
del args[arg_key]
251291

252292
if param.kind in keyword_match:
253-
f = partial(f, **{param.name: args[param.name]})
254-
del args[param.name]
293+
f = partial(f, **{param.name: args[arg_key]})
294+
del args[arg_key]
255295
continue
256296

257297
# At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +322,125 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282322
)
283323

284324
try:
325+
if is_async:
326+
result = f()
327+
if inspect.iscoroutine(result):
328+
return _run_async_coroutine(result)
329+
return result
285330
return f()
286331
except Exception:
287332
logger.exception("Error occurred while calling function %s", f_name)
288333
raise
334+
335+
336+
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
337+
"""
338+
Run a coroutine in an event loop.
339+
340+
Detects the async runtime (asyncio, trio, or curio) and executes the
341+
coroutine appropriately. Preserves ContextVars when creating a new event
342+
loop, which is important when handlers are called from threads.
343+
344+
Args:
345+
coro:
346+
The coroutine to run.
347+
348+
Returns:
349+
The result of the coroutine.
350+
351+
Raises:
352+
RuntimeError:
353+
If the detected async runtime (trio or curio) is not installed.
354+
RuntimeError:
355+
If the detected runtime is not installed.
356+
"""
357+
runtime = _detect_async_runtime_from_coroutine(coro)
358+
359+
if runtime == "trio":
360+
if trio is None:
361+
msg = "trio is not installed"
362+
raise RuntimeError(msg)
363+
364+
if current_trio_token is not None:
365+
try:
366+
token = current_trio_token()
367+
368+
async def _run_with_token() -> _T:
369+
return await coro
370+
371+
return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
372+
except RuntimeError:
373+
pass
374+
375+
ctx = contextvars.copy_context()
376+
377+
async def _run_trio() -> _T:
378+
return await coro
379+
380+
return ctx.run(trio.run, _run_trio)
381+
382+
if runtime == "curio":
383+
if curio is None:
384+
msg = "curio is not installed"
385+
raise RuntimeError(msg)
386+
387+
try:
388+
return curio.AWAIT(coro)
389+
except RuntimeError:
390+
ctx = contextvars.copy_context()
391+
392+
async def _run_curio() -> _T:
393+
return await coro
394+
395+
return ctx.run(curio.run, _run_curio)
396+
397+
try:
398+
loop = asyncio.get_running_loop()
399+
except RuntimeError:
400+
loop = None
401+
402+
if loop is not None:
403+
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
404+
return future.result()
405+
406+
ctx = contextvars.copy_context()
407+
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]
408+
409+
410+
def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str:
411+
"""
412+
Detect async runtime by inspecting the coroutine object.
413+
414+
Args:
415+
coro:
416+
The coroutine object to inspect.
417+
418+
Returns:
419+
The detected runtime: "asyncio", "trio", or "curio".
420+
"""
421+
if sniffio is not None:
422+
try:
423+
return sniffio.current_async_library()
424+
except sniffio.AsyncLibraryNotFoundError:
425+
pass
426+
427+
# Inspect the coroutine's code and globals to determine runtime
428+
func_code = coro.cr_code # type: ignore[attr-defined]
429+
code_names = set(func_code.co_names)
430+
431+
# Check if trio or curio modules are actually used in the code
432+
# If the module name is in co_names, it's being accessed (e.g., trio.sleep)
433+
# We also verify with specific indicators to avoid false positives
434+
# Note: If both trio and curio indicators are found, trio takes precedence
435+
if "trio" in code_names:
436+
trio_indicators = {"open_nursery", "current_trio_token", "CancelScope", "sleep"}
437+
if any(indicator in code_names for indicator in trio_indicators):
438+
return "trio"
439+
440+
if "curio" in code_names:
441+
curio_indicators = {"TaskGroup", "spawn", "AWAIT", "sleep"}
442+
if any(indicator in code_names for indicator in curio_indicators):
443+
return "curio"
444+
445+
# Default to asyncio as it's the most common
446+
return "asyncio"

0 commit comments

Comments
 (0)