Skip to content

Commit 971f06f

Browse files
Enable async state/message handlers (asyncio/trio/curio) with ContextVar preservation and tests
1 parent 02562fe commit 971f06f

File tree

2 files changed

+1007
-8
lines changed

2 files changed

+1007
-8
lines changed

src/pact/_util.py

Lines changed: 161 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,36 @@
99

1010
from __future__ import annotations
1111

12+
import asyncio
13+
import contextvars
1214
import inspect
1315
import logging
1416
import socket
1517
import warnings
1618
from contextlib import closing
1719
from functools import partial
1820
from inspect import Parameter, _ParameterKind
19-
from typing import TYPE_CHECKING, TypeVar
21+
from typing import TYPE_CHECKING, Any, TypeVar
2022

2123
if TYPE_CHECKING:
22-
from collections.abc import Callable, Mapping
24+
from collections.abc import Callable, Coroutine, Mapping
25+
26+
try:
27+
import sniffio # type: ignore[import-not-found]
28+
except ImportError:
29+
sniffio = None # type: ignore[assignment]
30+
31+
try:
32+
import trio # type: ignore[import-not-found]
33+
from trio.lowlevel import current_trio_token # type: ignore[import-not-found]
34+
except ImportError:
35+
trio = None # type: ignore[assignment]
36+
current_trio_token = None # type: ignore[assignment]
37+
38+
try:
39+
import curio # type: ignore[import-not-found,import-untyped]
40+
except (ImportError, AttributeError):
41+
curio = None # type: ignore[assignment]
2342

2443
logger = logging.getLogger(__name__)
2544

@@ -179,7 +198,7 @@ def find_free_port() -> int:
179198
return s.getsockname()[1]
180199

181200

182-
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
201+
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
183202
"""
184203
Apply arguments to a function.
185204
@@ -188,6 +207,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188207
it is possible to pass arguments by name, and falling back to positional
189208
arguments if not.
190209
210+
This function supports both synchronous and asynchronous callables. If the
211+
callable is a coroutine function, it will be executed in an event loop.
212+
191213
Args:
192214
f:
193215
The function to apply the arguments to.
@@ -200,6 +222,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200222
Returns:
201223
The result of the function.
202224
"""
225+
# Check if f is a partial wrapping an async function
226+
func_to_check = f.func if isinstance(f, partial) else f
227+
is_async = inspect.iscoroutinefunction(func_to_check)
203228
signature = inspect.signature(f)
204229
f_name = (
205230
f.__qualname__
@@ -226,7 +251,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226251
# First, we inspect the keyword arguments and try and pass in some arguments
227252
# by currying them in.
228253
for param in signature.parameters.values():
229-
if param.name not in args:
254+
# Try matching the parameter name, or if it starts with underscore,
255+
# also try matching without the leading underscore.
256+
arg_key = None
257+
if param.name in args:
258+
arg_key = param.name
259+
elif (
260+
param.name.startswith("_")
261+
and len(param.name) > 1
262+
and param.name[1:] in args
263+
):
264+
arg_key = param.name[1:]
265+
266+
if arg_key is None:
230267
# If a parameter is not known, we will ignore it.
231268
#
232269
# If the ignored parameter doesn't have a default value, it will
@@ -246,12 +283,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246283
if param.kind in positional_match:
247284
# We iterate through the parameters in order that they are defined,
248285
# making it fine to pass them in by position one at a time.
249-
f = partial(f, args[param.name])
250-
del args[param.name]
286+
f = partial(f, args[arg_key])
287+
del args[arg_key]
251288

252289
if param.kind in keyword_match:
253-
f = partial(f, **{param.name: args[param.name]})
254-
del args[param.name]
290+
f = partial(f, **{param.name: args[arg_key]})
291+
del args[arg_key]
255292
continue
256293

257294
# At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +319,123 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282319
)
283320

284321
try:
322+
if is_async:
323+
result = f()
324+
if inspect.iscoroutine(result):
325+
return _run_async_coroutine(result)
326+
return result
285327
return f()
286328
except Exception:
287329
logger.exception("Error occurred while calling function %s", f_name)
288330
raise
331+
332+
333+
def _run_async_coroutine(coro: Coroutine[Any, Any, _T]) -> _T: # noqa: C901
334+
"""
335+
Run a coroutine in an event loop.
336+
337+
Detects the async runtime (asyncio, trio, or curio) and executes the
338+
coroutine appropriately. Preserves ContextVars when creating a new event
339+
loop, which is important when handlers are called from threads.
340+
341+
Args:
342+
coro:
343+
The coroutine to run.
344+
345+
Returns:
346+
The result of the coroutine.
347+
348+
Raises:
349+
RuntimeError:
350+
If the detected runtime (trio or curio) is not installed.
351+
"""
352+
runtime = _detect_async_runtime_from_coroutine(coro)
353+
354+
if runtime == "trio":
355+
if trio is None:
356+
msg = "trio is not installed"
357+
raise RuntimeError(msg)
358+
359+
if current_trio_token is not None:
360+
try:
361+
token = current_trio_token()
362+
363+
async def _run_with_token() -> _T:
364+
return await coro
365+
366+
return trio.from_thread.run_sync(_run_with_token, trio_token=token) # type: ignore[return-value]
367+
except RuntimeError:
368+
pass
369+
370+
ctx = contextvars.copy_context()
371+
372+
async def _run_trio() -> _T:
373+
return await coro
374+
375+
return ctx.run(trio.run, _run_trio)
376+
377+
if runtime == "curio":
378+
if curio is None:
379+
msg = "curio is not installed"
380+
raise RuntimeError(msg)
381+
382+
try:
383+
return curio.AWAIT(coro)
384+
except RuntimeError:
385+
ctx = contextvars.copy_context()
386+
387+
async def _run_curio() -> _T:
388+
return await coro
389+
390+
return ctx.run(curio.run, _run_curio)
391+
392+
try:
393+
loop = asyncio.get_running_loop()
394+
except RuntimeError:
395+
loop = None
396+
397+
if loop is not None:
398+
future: asyncio.Future[_T] = asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type,assignment]
399+
return future.result()
400+
401+
ctx = contextvars.copy_context()
402+
return ctx.run(asyncio.run, coro) # type: ignore[arg-type,return-value]
403+
404+
405+
def _detect_async_runtime_from_coroutine(coro: Coroutine[Any, Any, _T]) -> str:
406+
"""
407+
Detect async runtime by inspecting the coroutine object.
408+
409+
Args:
410+
coro:
411+
The coroutine object to inspect.
412+
413+
Returns:
414+
The detected runtime: "asyncio", "trio", or "curio".
415+
"""
416+
if sniffio is not None:
417+
try:
418+
return sniffio.current_async_library()
419+
except sniffio.AsyncLibraryNotFoundError:
420+
pass
421+
422+
# Inspect the coroutine's code and globals to determine runtime
423+
func_code = coro.cr_code # type: ignore[attr-defined]
424+
code_names = set(func_code.co_names)
425+
426+
# Check if trio or curio modules are actually used in the code
427+
# If the module name is in co_names, it's being accessed (e.g., trio.sleep)
428+
# We also verify with specific indicators to avoid false positives
429+
# Note: If both trio and curio indicators are found, trio takes precedence
430+
if "trio" in code_names:
431+
trio_indicators = {"open_nursery", "current_trio_token", "CancelScope", "sleep"}
432+
if any(indicator in code_names for indicator in trio_indicators):
433+
return "trio"
434+
435+
if "curio" in code_names:
436+
curio_indicators = {"TaskGroup", "spawn", "AWAIT", "sleep"}
437+
if any(indicator in code_names for indicator in curio_indicators):
438+
return "curio"
439+
440+
# Default to asyncio as it's the most common
441+
return "asyncio"

0 commit comments

Comments
 (0)