99
1010from __future__ import annotations
1111
12+ import asyncio
13+ import contextvars
1214import inspect
1315import logging
1416import socket
1517import warnings
1618from contextlib import closing
1719from functools import partial
1820from inspect import Parameter , _ParameterKind
19- from typing import TYPE_CHECKING , TypeVar
21+ from typing import TYPE_CHECKING , Any , TypeVar
2022
2123if TYPE_CHECKING :
22- from collections .abc import Callable , Mapping
24+ from collections .abc import Callable , Coroutine , Mapping
25+
26+ try :
27+ import sniffio # type: ignore[import-not-found]
28+ except ImportError :
29+ sniffio = None # type: ignore[assignment]
30+
31+ try :
32+ import trio # type: ignore[import-not-found]
33+ from trio .lowlevel import current_trio_token # type: ignore[import-not-found]
34+ except ImportError :
35+ trio = None # type: ignore[assignment]
36+ current_trio_token = None # type: ignore[assignment]
37+
38+ try :
39+ import curio # type: ignore[import-not-found,import-untyped]
40+ except (ImportError , AttributeError ):
41+ curio = None # type: ignore[assignment]
42+
43+ if TYPE_CHECKING :
44+ from collections .abc import Callable , Coroutine , Mapping
2345
2446logger = logging .getLogger (__name__ )
2547
@@ -179,7 +201,7 @@ def find_free_port() -> int:
179201 return s .getsockname ()[1 ]
180202
181203
182- def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901
204+ def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901, PLR0912
183205 """
184206 Apply arguments to a function.
185207
@@ -188,6 +210,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188210 it is possible to pass arguments by name, and falling back to positional
189211 arguments if not.
190212
213+ This function supports both synchronous and asynchronous callables. If the
214+ callable is a coroutine function, it will be executed in an event loop.
215+
191216 Args:
192217 f:
193218 The function to apply the arguments to.
@@ -200,6 +225,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200225 Returns:
201226 The result of the function.
202227 """
228+ # Check if f is a partial wrapping an async function
229+ func_to_check = f .func if isinstance (f , partial ) else f
230+ is_async = inspect .iscoroutinefunction (func_to_check )
203231 signature = inspect .signature (f )
204232 f_name = (
205233 f .__qualname__
@@ -226,7 +254,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226254 # First, we inspect the keyword arguments and try and pass in some arguments
227255 # by currying them in.
228256 for param in signature .parameters .values ():
229- if param .name not in args :
257+ # Try matching the parameter name, or if it starts with underscore,
258+ # also try matching without the leading underscore.
259+ arg_key = None
260+ if param .name in args :
261+ arg_key = param .name
262+ elif (
263+ param .name .startswith ("_" )
264+ and len (param .name ) > 1
265+ and param .name [1 :] in args
266+ ):
267+ arg_key = param .name [1 :]
268+
269+ if arg_key is None :
230270 # If a parameter is not known, we will ignore it.
231271 #
232272 # If the ignored parameter doesn't have a default value, it will
@@ -246,12 +286,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246286 if param .kind in positional_match :
247287 # We iterate through the parameters in order that they are defined,
248288 # making it fine to pass them in by position one at a time.
249- f = partial (f , args [param . name ])
250- del args [param . name ]
289+ f = partial (f , args [arg_key ])
290+ del args [arg_key ]
251291
252292 if param .kind in keyword_match :
253- f = partial (f , ** {param .name : args [param . name ]})
254- del args [param . name ]
293+ f = partial (f , ** {param .name : args [arg_key ]})
294+ del args [arg_key ]
255295 continue
256296
257297 # At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +322,125 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282322 )
283323
284324 try :
325+ if is_async :
326+ result = f ()
327+ if inspect .iscoroutine (result ):
328+ return _run_async_coroutine (result )
329+ return result
285330 return f ()
286331 except Exception :
287332 logger .exception ("Error occurred while calling function %s" , f_name )
288333 raise
334+
335+
336+ def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T : # noqa: C901
337+ """
338+ Run a coroutine in an event loop.
339+
340+ Detects the async runtime (asyncio, trio, or curio) and executes the
341+ coroutine appropriately. Preserves ContextVars when creating a new event
342+ loop, which is important when handlers are called from threads.
343+
344+ Args:
345+ coro:
346+ The coroutine to run.
347+
348+ Returns:
349+ The result of the coroutine.
350+
351+ Raises:
352+ RuntimeError:
353+ If the detected async runtime (trio or curio) is not installed.
354+ RuntimeError:
355+ If the detected runtime is not installed.
356+ """
357+ runtime = _detect_async_runtime_from_coroutine (coro )
358+
359+ if runtime == "trio" :
360+ if trio is None :
361+ msg = "trio is not installed"
362+ raise RuntimeError (msg )
363+
364+ if current_trio_token is not None :
365+ try :
366+ token = current_trio_token ()
367+
368+ async def _run_with_token () -> _T :
369+ return await coro
370+
371+ return trio .from_thread .run_sync (_run_with_token , trio_token = token ) # type: ignore[return-value]
372+ except RuntimeError :
373+ pass
374+
375+ ctx = contextvars .copy_context ()
376+
377+ async def _run_trio () -> _T :
378+ return await coro
379+
380+ return ctx .run (trio .run , _run_trio )
381+
382+ if runtime == "curio" :
383+ if curio is None :
384+ msg = "curio is not installed"
385+ raise RuntimeError (msg )
386+
387+ try :
388+ return curio .AWAIT (coro )
389+ except RuntimeError :
390+ ctx = contextvars .copy_context ()
391+
392+ async def _run_curio () -> _T :
393+ return await coro
394+
395+ return ctx .run (curio .run , _run_curio )
396+
397+ try :
398+ loop = asyncio .get_running_loop ()
399+ except RuntimeError :
400+ loop = None
401+
402+ if loop is not None :
403+ future : asyncio .Future [_T ] = asyncio .run_coroutine_threadsafe (coro , loop ) # type: ignore[arg-type,assignment]
404+ return future .result ()
405+
406+ ctx = contextvars .copy_context ()
407+ return ctx .run (asyncio .run , coro ) # type: ignore[arg-type,return-value]
408+
409+
410+ def _detect_async_runtime_from_coroutine (coro : Coroutine [Any , Any , _T ]) -> str :
411+ """
412+ Detect async runtime by inspecting the coroutine object.
413+
414+ Args:
415+ coro:
416+ The coroutine object to inspect.
417+
418+ Returns:
419+ The detected runtime: "asyncio", "trio", or "curio".
420+ """
421+ if sniffio is not None :
422+ try :
423+ return sniffio .current_async_library ()
424+ except sniffio .AsyncLibraryNotFoundError :
425+ pass
426+
427+ # Inspect the coroutine's code and globals to determine runtime
428+ func_code = coro .cr_code # type: ignore[attr-defined]
429+ code_names = set (func_code .co_names )
430+
431+ # Check if trio or curio modules are actually used in the code
432+ # If the module name is in co_names, it's being accessed (e.g., trio.sleep)
433+ # We also verify with specific indicators to avoid false positives
434+ # Note: If both trio and curio indicators are found, trio takes precedence
435+ if "trio" in code_names :
436+ trio_indicators = {"open_nursery" , "current_trio_token" , "CancelScope" , "sleep" }
437+ if any (indicator in code_names for indicator in trio_indicators ):
438+ return "trio"
439+
440+ if "curio" in code_names :
441+ curio_indicators = {"TaskGroup" , "spawn" , "AWAIT" , "sleep" }
442+ if any (indicator in code_names for indicator in curio_indicators ):
443+ return "curio"
444+
445+ # Default to asyncio as it's the most common
446+ return "asyncio"
0 commit comments