99
1010from __future__ import annotations
1111
12+ import asyncio
13+ import contextvars
1214import inspect
1315import logging
1416import socket
1517import warnings
1618from contextlib import closing
1719from functools import partial
1820from inspect import Parameter , _ParameterKind
19- from typing import TYPE_CHECKING , TypeVar
21+ from typing import TYPE_CHECKING , Any , TypeVar
2022
2123if TYPE_CHECKING :
22- from collections .abc import Callable , Mapping
24+ from collections .abc import Callable , Coroutine , Mapping
25+
26+ try :
27+ import sniffio # type: ignore[import-not-found]
28+ except ImportError :
29+ sniffio = None # type: ignore[assignment]
30+
31+ try :
32+ import trio # type: ignore[import-not-found]
33+ from trio .lowlevel import current_trio_token # type: ignore[import-not-found]
34+ except ImportError :
35+ trio = None # type: ignore[assignment]
36+ current_trio_token = None # type: ignore[assignment]
37+
38+ try :
39+ import curio # type: ignore[import-not-found,import-untyped]
40+ except (ImportError , AttributeError ):
41+ curio = None # type: ignore[assignment]
42+
43+ if TYPE_CHECKING :
44+ from collections .abc import Callable , Coroutine , Mapping
2345
2446logger = logging .getLogger (__name__ )
2547
@@ -179,7 +201,7 @@ def find_free_port() -> int:
179201 return s .getsockname ()[1 ]
180202
181203
182- def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901
204+ def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901, PLR0912
183205 """
184206 Apply arguments to a function.
185207
@@ -188,6 +210,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188210 it is possible to pass arguments by name, and falling back to positional
189211 arguments if not.
190212
213+ This function supports both synchronous and asynchronous callables. If the
214+ callable is a coroutine function, it will be executed in an event loop.
215+
191216 Args:
192217 f:
193218 The function to apply the arguments to.
@@ -200,6 +225,7 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200225 Returns:
201226 The result of the function.
202227 """
228+ is_async = inspect .iscoroutinefunction (f )
203229 signature = inspect .signature (f )
204230 f_name = (
205231 f .__qualname__
@@ -226,7 +252,15 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226252 # First, we inspect the keyword arguments and try and pass in some arguments
227253 # by currying them in.
228254 for param in signature .parameters .values ():
229- if param .name not in args :
255+ # Try matching the parameter name, or if it starts with underscore,
256+ # also try matching without the leading underscore.
257+ arg_key = None
258+ if param .name in args :
259+ arg_key = param .name
260+ elif param .name .startswith ("_" ) and param .name [1 :] in args :
261+ arg_key = param .name [1 :]
262+
263+ if arg_key is None :
230264 # If a parameter is not known, we will ignore it.
231265 #
232266 # If the ignored parameter doesn't have a default value, it will
@@ -246,12 +280,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246280 if param .kind in positional_match :
247281 # We iterate through the parameters in order that they are defined,
248282 # making it fine to pass them in by position one at a time.
249- f = partial (f , args [param . name ])
250- del args [param . name ]
283+ f = partial (f , args [arg_key ])
284+ del args [arg_key ]
251285
252286 if param .kind in keyword_match :
253- f = partial (f , ** {param .name : args [param . name ]})
254- del args [param . name ]
287+ f = partial (f , ** {param .name : args [arg_key ]})
288+ del args [arg_key ]
255289 continue
256290
257291 # At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +316,122 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282316 )
283317
284318 try :
319+ if is_async :
320+ result = f ()
321+ if inspect .iscoroutine (result ):
322+ return _run_async_coroutine (result )
323+ return result
285324 return f ()
286325 except Exception :
287326 logger .exception ("Error occurred while calling function %s" , f_name )
288327 raise
328+
329+
330+ def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T : # noqa: C901
331+ """
332+ Run a coroutine in an event loop.
333+
334+ Detects the async runtime (asyncio, trio, or curio) and executes the
335+ coroutine appropriately. Preserves ContextVars when creating a new event
336+ loop, which is important when handlers are called from threads.
337+
338+ Args:
339+ coro:
340+ The coroutine to run.
341+
342+ Returns:
343+ The result of the coroutine.
344+
345+ Raises:
346+ RuntimeError:
347+ If the detected runtime is not installed.
348+ """
349+ runtime = _detect_async_runtime_from_coroutine (coro )
350+
351+ if runtime == "trio" :
352+ if trio is None :
353+ msg = "trio is not installed"
354+ raise RuntimeError (msg )
355+
356+ if current_trio_token is not None :
357+ try :
358+ token = current_trio_token ()
359+
360+ async def _run_with_token () -> _T :
361+ return await coro
362+
363+ return trio .from_thread .run_sync (_run_with_token , trio_token = token ) # type: ignore[return-value]
364+ except RuntimeError :
365+ pass
366+
367+ ctx = contextvars .copy_context ()
368+
369+ async def _run_trio () -> _T :
370+ return await coro
371+
372+ return ctx .run (trio .run , _run_trio )
373+
374+ if runtime == "curio" :
375+ if curio is None :
376+ msg = "curio is not installed"
377+ raise RuntimeError (msg )
378+
379+ try :
380+ return curio .AWAIT (coro )
381+ except RuntimeError :
382+ ctx = contextvars .copy_context ()
383+
384+ async def _run_curio () -> _T :
385+ return await coro
386+
387+ return ctx .run (curio .run , _run_curio )
388+
389+ try :
390+ loop = asyncio .get_running_loop ()
391+ except RuntimeError :
392+ loop = None
393+
394+ if loop is not None :
395+ future : asyncio .Future [_T ] = asyncio .run_coroutine_threadsafe (coro , loop ) # type: ignore[arg-type,assignment]
396+ return future .result ()
397+
398+ ctx = contextvars .copy_context ()
399+ return ctx .run (asyncio .run , coro ) # type: ignore[arg-type,return-value]
400+
401+
402+ def _detect_async_runtime_from_coroutine (coro : Coroutine [Any , Any , _T ]) -> str :
403+ """
404+ Detect async runtime by inspecting the coroutine object.
405+
406+ Args:
407+ coro:
408+ The coroutine object to inspect.
409+
410+ Returns:
411+ The detected runtime: "asyncio", "trio", or "curio".
412+ """
413+ if sniffio is not None :
414+ try :
415+ return sniffio .current_async_library ()
416+ except sniffio .AsyncLibraryNotFoundError :
417+ pass
418+
419+ # Inspect the coroutine's code and globals to determine runtime
420+ func_code = coro .cr_code # type: ignore[attr-defined]
421+ code_names = set (func_code .co_names )
422+
423+ # Check if trio or curio modules are actually used in the code
424+ # If the module name is in co_names, it's being accessed (e.g., trio.sleep)
425+ # We also verify with specific indicators to avoid false positives
426+ if "trio" in code_names :
427+ trio_indicators = {"open_nursery" , "current_trio_token" , "CancelScope" , "sleep" }
428+ if any (indicator in code_names for indicator in trio_indicators ):
429+ return "trio"
430+
431+ if "curio" in code_names :
432+ curio_indicators = {"TaskGroup" , "spawn" , "AWAIT" , "sleep" }
433+ if any (indicator in code_names for indicator in curio_indicators ):
434+ return "curio"
435+
436+ # Default to asyncio as it's the most common
437+ return "asyncio"
0 commit comments