Skip to content

Commit 868f689

Browse files
Enable async state/message handlers (asyncio/trio/curio) with ContextVar preservation and tests
1 parent 02562fe commit 868f689

File tree

2 files changed

+603
-8
lines changed

2 files changed

+603
-8
lines changed

src/pact/_util.py

Lines changed: 157 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,39 @@
99

1010
from __future__ import annotations
1111

12+
import asyncio
13+
import contextvars
1214
import inspect
1315
import logging
1416
import socket
1517
import warnings
1618
from contextlib import closing
1719
from functools import partial
1820
from inspect import Parameter, _ParameterKind
19-
from typing import TYPE_CHECKING, TypeVar
21+
from typing import TYPE_CHECKING, Any, TypeVar
2022

2123
if TYPE_CHECKING:
22-
from collections.abc import Callable, Mapping
24+
from collections.abc import Callable, Coroutine, Mapping
25+
26+
try:
27+
import sniffio # type: ignore[import-not-found]
28+
except ImportError:
29+
sniffio = None # type: ignore[assignment]
30+
31+
try:
32+
import trio # type: ignore[import-not-found]
33+
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
34+
except ImportError:
35+
trio = None # type: ignore[assignment]
36+
current_trio_token = None # type: ignore[assignment]
37+
38+
try:
39+
import curio # type: ignore[import-not-found,import-untyped]
40+
except (ImportError, AttributeError):
41+
curio = None # type: ignore[assignment]
42+
43+
if TYPE_CHECKING:
44+
from collections.abc import Callable, Coroutine, Mapping
2345

2446
logger = logging.getLogger(__name__)
2547

@@ -179,7 +201,7 @@ def find_free_port() -> int:
179201
return s.getsockname()[1]
180202

181203

182-
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
204+
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
183205
"""
184206
Apply arguments to a function.
185207
@@ -188,6 +210,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188210
it is possible to pass arguments by name, and falling back to positional
189211
arguments if not.
190212
213+
This function supports both synchronous and asynchronous callables. If the
214+
callable is a coroutine function, it will be executed in an event loop.
215+
191216
Args:
192217
f:
193218
The function to apply the arguments to.
@@ -200,6 +225,7 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200225
Returns:
201226
The result of the function.
202227
"""
228+
is_async = inspect.iscoroutinefunction(f)
203229
signature = inspect.signature(f)
204230
f_name = (
205231
f.__qualname__
@@ -226,7 +252,15 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226252
# First, we inspect the keyword arguments and try and pass in some arguments
227253
# by currying them in.
228254
for param in signature.parameters.values():
229-
if param.name not in args:
255+
# Try matching the parameter name, or if it starts with underscore,
256+
# also try matching without the leading underscore.
257+
arg_key = None
258+
if param.name in args:
259+
arg_key = param.name
260+
elif param.name.startswith("_") and param.name[1:] in args:
261+
arg_key = param.name[1:]
262+
263+
if arg_key is None:
230264
# If a parameter is not known, we will ignore it.
231265
#
232266
# If the ignored parameter doesn't have a default value, it will
@@ -246,12 +280,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246280
if param.kind in positional_match:
247281
# We iterate through the parameters in order that they are defined,
248282
# making it fine to pass them in by position one at a time.
249-
f = partial(f, args[param.name])
250-
del args[param.name]
283+
f = partial(f, args[arg_key])
284+
del args[arg_key]
251285

252286
if param.kind in keyword_match:
253-
f = partial(f, **{param.name: args[param.name]})
254-
del args[param.name]
287+
f = partial(f, **{param.name: args[arg_key]})
288+
del args[arg_key]
255289
continue
256290

257291
# At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +316,122 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282316
)
283317

284318
try:
319+
if is_async:
320+
result = f()
321+
if inspect.iscoroutine(result):
322+
return _run_async_coroutine(result)
323+
return result
285324
return f()
286325
except Exception:
287326
logger.exception("Error occurred while calling function %s", f_name)
288327
raise
328+
329+
330+
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
331+
"""
332+
Run a coroutine in an event loop.
333+
334+
Detects the async runtime (asyncio, trio, or curio) and executes the
335+
coroutine appropriately. Preserves ContextVars when creating a new event
336+
loop, which is important when handlers are called from threads.
337+
338+
Args:
339+
coro:
340+
The coroutine to run.
341+
342+
Returns:
343+
The result of the coroutine.
344+
345+
Raises:
346+
RuntimeError:
347+
If the detected runtime is not installed.
348+
"""
349+
runtime = _detect_async_runtime_from_coroutine(coro)
350+
351+
if runtime == "trio":
352+
if trio is None:
353+
msg = "trio is not installed"
354+
raise RuntimeError(msg)
355+
356+
if current_trio_token is not None:
357+
try:
358+
token = current_trio_token()
359+
360+
async def _run_with_token() -> _T:
361+
return await coro
362+
363+
return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
364+
except RuntimeError:
365+
pass
366+
367+
ctx = contextvars.copy_context()
368+
369+
async def _run_trio() -> _T:
370+
return await coro
371+
372+
return ctx.run(trio.run, _run_trio)
373+
374+
if runtime == "curio":
375+
if curio is None:
376+
msg = "curio is not installed"
377+
raise RuntimeError(msg)
378+
379+
try:
380+
return curio.AWAIT(coro)
381+
except RuntimeError:
382+
ctx = contextvars.copy_context()
383+
384+
async def _run_curio() -> _T:
385+
return await coro
386+
387+
return ctx.run(curio.run, _run_curio)
388+
389+
try:
390+
loop = asyncio.get_running_loop()
391+
except RuntimeError:
392+
loop = None
393+
394+
if loop is not None:
395+
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
396+
return future.result()
397+
398+
ctx = contextvars.copy_context()
399+
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]
400+
401+
402+
def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str:
403+
"""
404+
Detect async runtime by inspecting the coroutine object.
405+
406+
Args:
407+
coro:
408+
The coroutine object to inspect.
409+
410+
Returns:
411+
The detected runtime: "asyncio", "trio", or "curio".
412+
"""
413+
if sniffio is not None:
414+
try:
415+
return sniffio.current_async_library()
416+
except sniffio.AsyncLibraryNotFoundError:
417+
pass
418+
419+
# Inspect the coroutine's code and globals to determine runtime
420+
func_code = coro.cr_code # type: ignore[attr-defined]
421+
code_names = set(func_code.co_names)
422+
423+
# Check if trio or curio modules are actually used in the code
424+
# If the module name is in co_names, it's being accessed (e.g., trio.sleep)
425+
# We also verify with specific indicators to avoid false positives
426+
if "trio" in code_names:
427+
trio_indicators = {"open_nursery", "current_trio_token", "CancelScope", "sleep"}
428+
if any(indicator in code_names for indicator in trio_indicators):
429+
return "trio"
430+
431+
if "curio" in code_names:
432+
curio_indicators = {"TaskGroup", "spawn", "AWAIT", "sleep"}
433+
if any(indicator in code_names for indicator in curio_indicators):
434+
return "curio"
435+
436+
# Default to asyncio as it's the most common
437+
return "asyncio"

0 commit comments

Comments
 (0)