Skip to content

Commit 770638f

Browse files
Enable async state/message handlers (asyncio/trio/curio) with ContextVar preservation and tests
1 parent 02562fe commit 770638f

File tree

2 files changed

+583
-6
lines changed

2 files changed

+583
-6
lines changed

src/pact/_util.py

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
from __future__ import annotations
1111

12+
import asyncio
13+
import contextvars
1214
import inspect
1315
import logging
1416
import socket
@@ -18,6 +20,23 @@
1820
from inspect import Parameter, _ParameterKind
1921
from typing import TYPE_CHECKING, TypeVar
2022

23+
try:
24+
import sniffio
25+
except ImportError:
26+
sniffio = None # type: ignore[assignment]
27+
28+
try:
29+
import trio
30+
from trio.lowlevel import current_trio_token
31+
except ImportError:
32+
trio = None # type: ignore[assignment]
33+
current_trio_token = None # type: ignore[assignment]
34+
35+
try:
36+
import curio
37+
except (ImportError, AttributeError):
38+
curio = None # type: ignore[assignment]
39+
2140
if TYPE_CHECKING:
2241
from collections.abc import Callable, Mapping
2342

@@ -179,7 +198,7 @@ def find_free_port() -> int:
179198
return s.getsockname()[1]
180199

181200

182-
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901
201+
def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa: C901, PLR0912
183202
"""
184203
Apply arguments to a function.
185204
@@ -188,6 +207,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188207
it is possible to pass arguments by name, and falling back to positional
189208
arguments if not.
190209
210+
This function supports both synchronous and asynchronous callables. If the
211+
callable is a coroutine function, it will be executed in an event loop.
212+
191213
Args:
192214
f:
193215
The function to apply the arguments to.
@@ -200,6 +222,7 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200222
Returns:
201223
The result of the function.
202224
"""
225+
is_async = inspect.iscoroutinefunction(f)
203226
signature = inspect.signature(f)
204227
f_name = (
205228
f.__qualname__
@@ -226,7 +249,15 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226249
# First, we inspect the keyword arguments and try and pass in some arguments
227250
# by currying them in.
228251
for param in signature.parameters.values():
229-
if param.name not in args:
252+
# Try matching the parameter name, or if it starts with underscore,
253+
# also try matching without the leading underscore.
254+
arg_key = None
255+
if param.name in args:
256+
arg_key = param.name
257+
elif param.name.startswith("_") and param.name[1:] in args:
258+
arg_key = param.name[1:]
259+
260+
if arg_key is None:
230261
# If a parameter is not known, we will ignore it.
231262
#
232263
# If the ignored parameter doesn't have a default value, it will
@@ -246,12 +277,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246277
if param.kind in positional_match:
247278
# We iterate through the parameters in order that they are defined,
248279
# making it fine to pass them in by position one at a time.
249-
f = partial(f, args[param.name])
250-
del args[param.name]
280+
f = partial(f, args[arg_key])
281+
del args[arg_key]
251282

252283
if param.kind in keyword_match:
253-
f = partial(f, **{param.name: args[param.name]})
254-
del args[param.name]
284+
f = partial(f, **{param.name: args[arg_key]})
285+
del args[arg_key]
255286
continue
256287

257288
# At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +313,107 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282313
)
283314

284315
try:
316+
if is_async:
317+
return _run_async(f)
285318
return f()
286319
except Exception:
287320
logger.exception("Error occurred while calling function %s", f_name)
288321
raise
322+
323+
324+
def _run_async(coro_func: Callable[..., _T]) -> _T:
325+
"""
326+
Run an async function in an event loop.
327+
328+
Detects the async runtime (asyncio, trio, or curio) and executes the
329+
coroutine appropriately. Preserves ContextVars when creating a new event
330+
loop, which is important when handlers are called from threads.
331+
332+
Args:
333+
coro_func:
334+
The async function to run (already partially applied).
335+
336+
Returns:
337+
The result of the async function.
338+
339+
Raises:
340+
RuntimeError:
341+
If the detected runtime is not installed.
342+
"""
343+
runtime = _detect_async_runtime_from_coro(coro_func)
344+
345+
if runtime == "trio":
346+
if trio is None:
347+
msg = "trio is not installed"
348+
raise RuntimeError(msg)
349+
350+
if current_trio_token is not None:
351+
try:
352+
token = current_trio_token()
353+
return trio.from_thread.run_sync(coro_func, trio_token=token)
354+
except RuntimeError:
355+
pass
356+
357+
ctx = contextvars.copy_context()
358+
return ctx.run(trio.run, coro_func)
359+
360+
if runtime == "curio":
361+
if curio is None:
362+
msg = "curio is not installed"
363+
raise RuntimeError(msg)
364+
365+
try:
366+
return curio.AWAIT(coro_func())
367+
except RuntimeError:
368+
ctx = contextvars.copy_context()
369+
return ctx.run(curio.run, coro_func)
370+
371+
try:
372+
loop = asyncio.get_running_loop()
373+
except RuntimeError:
374+
loop = None
375+
376+
if loop is not None:
377+
future = asyncio.run_coroutine_threadsafe(coro_func(), loop)
378+
return future.result()
379+
380+
ctx = contextvars.copy_context()
381+
return ctx.run(asyncio.run, coro_func())
382+
383+
384+
def _detect_async_runtime_from_coro(coro_func: Callable[..., _T]) -> str:
385+
"""
386+
Detect async runtime by inspecting the coroutine function.
387+
388+
Args:
389+
coro_func:
390+
The async function to inspect. May be a partial object.
391+
392+
Returns:
393+
The detected runtime: "asyncio", "trio", or "curio".
394+
"""
395+
if sniffio is not None:
396+
try:
397+
return sniffio.current_async_library()
398+
except sniffio.AsyncLibraryNotFoundError:
399+
pass
400+
401+
func = getattr(coro_func, "func", coro_func)
402+
func_code = getattr(func, "__code__", None)
403+
func_globals = getattr(func, "__globals__", {})
404+
405+
if func_code is not None:
406+
code_names = set(func_code.co_names)
407+
408+
trio_indicators = {"trio", "open_nursery", "current_trio_token"}
409+
curio_indicators = {"curio", "TaskGroup", "AWAIT"}
410+
411+
uses_trio = bool(trio_indicators & code_names)
412+
uses_curio = bool(curio_indicators & code_names)
413+
414+
if uses_trio and "trio" in func_globals:
415+
return "trio"
416+
if uses_curio and "curio" in func_globals:
417+
return "curio"
418+
419+
return "asyncio"

0 commit comments

Comments
 (0)