99
1010from __future__ import annotations
1111
12+ import asyncio
13+ import contextvars
1214import inspect
1315import logging
1416import socket
1517import warnings
1618from contextlib import closing
1719from functools import partial
1820from inspect import Parameter , _ParameterKind
19- from typing import TYPE_CHECKING , TypeVar
21+ from typing import TYPE_CHECKING , Any , TypeVar
2022
2123if TYPE_CHECKING :
22- from collections .abc import Callable , Mapping
24+ from collections .abc import Callable , Coroutine , Mapping
25+
26+ try :
27+ import sniffio # type: ignore[import-not-found]
28+ except ImportError :
29+ sniffio = None # type: ignore[assignment]
30+
31+ try :
32+ import trio # type: ignore[import-not-found]
33+ from trio .lowlevel import current_trio_token # type: ignore[import-not-found]
34+ except ImportError :
35+ trio = None # type: ignore[assignment]
36+ current_trio_token = None # type: ignore[assignment]
37+
38+ try :
39+ import curio # type: ignore[import-not-found,import-untyped]
40+ except (ImportError , AttributeError ):
41+ curio = None # type: ignore[assignment]
2342
2443logger = logging .getLogger (__name__ )
2544
@@ -179,7 +198,7 @@ def find_free_port() -> int:
179198 return s .getsockname ()[1 ]
180199
181200
182- def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901
201+ def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901, PLR0912
183202 """
184203 Apply arguments to a function.
185204
@@ -188,6 +207,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188207 it is possible to pass arguments by name, and falling back to positional
189208 arguments if not.
190209
210+ This function supports both synchronous and asynchronous callables. If the
211+ callable is a coroutine function, it will be executed in an event loop.
212+
191213 Args:
192214 f:
193215 The function to apply the arguments to.
@@ -200,6 +222,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200222 Returns:
201223 The result of the function.
202224 """
225+ # Check if f is a partial wrapping an async function
226+ func_to_check = f .func if isinstance (f , partial ) else f
227+ is_async = inspect .iscoroutinefunction (func_to_check )
203228 signature = inspect .signature (f )
204229 f_name = (
205230 f .__qualname__
@@ -226,7 +251,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226251 # First, we inspect the keyword arguments and try and pass in some arguments
227252 # by currying them in.
228253 for param in signature .parameters .values ():
229- if param .name not in args :
254+ # Try matching the parameter name, or if it starts with underscore,
255+ # also try matching without the leading underscore.
256+ arg_key = None
257+ if param .name in args :
258+ arg_key = param .name
259+ elif (
260+ param .name .startswith ("_" )
261+ and len (param .name ) > 1
262+ and param .name [1 :] in args
263+ ):
264+ arg_key = param .name [1 :]
265+
266+ if arg_key is None :
230267 # If a parameter is not known, we will ignore it.
231268 #
232269 # If the ignored parameter doesn't have a default value, it will
@@ -246,12 +283,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246283 if param .kind in positional_match :
247284 # We iterate through the parameters in order that they are defined,
248285 # making it fine to pass them in by position one at a time.
249- f = partial (f , args [param . name ])
250- del args [param . name ]
286+ f = partial (f , args [arg_key ])
287+ del args [arg_key ]
251288
252289 if param .kind in keyword_match :
253- f = partial (f , ** {param .name : args [param . name ]})
254- del args [param . name ]
290+ f = partial (f , ** {param .name : args [arg_key ]})
291+ del args [arg_key ]
255292 continue
256293
257294 # At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +319,123 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282319 )
283320
284321 try :
322+ if is_async :
323+ result = f ()
324+ if inspect .iscoroutine (result ):
325+ return _run_async_coroutine (result )
326+ return result
285327 return f ()
286328 except Exception :
287329 logger .exception ("Error occurred while calling function %s" , f_name )
288330 raise
331+
332+
333+ def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T : # noqa: C901
334+ """
335+ Run a coroutine in an event loop.
336+
337+ Detects the async runtime (asyncio, trio, or curio) and executes the
338+ coroutine appropriately. Preserves ContextVars when creating a new event
339+ loop, which is important when handlers are called from threads.
340+
341+ Args:
342+ coro:
343+ The coroutine to run.
344+
345+ Returns:
346+ The result of the coroutine.
347+
348+ Raises:
349+ RuntimeError:
350+ If the detected runtime (trio or curio) is not installed.
351+ """
352+ runtime = _detect_async_runtime_from_coroutine (coro )
353+
354+ if runtime == "trio" :
355+ if trio is None :
356+ msg = "trio is not installed"
357+ raise RuntimeError (msg )
358+
359+ if current_trio_token is not None :
360+ try :
361+ token = current_trio_token ()
362+
363+ async def _run_with_token () -> _T :
364+ return await coro
365+
366+ return trio .from_thread .run_sync (_run_with_token , trio_token = token ) # type: ignore[return-value]
367+ except RuntimeError :
368+ pass
369+
370+ ctx = contextvars .copy_context ()
371+
372+ async def _run_trio () -> _T :
373+ return await coro
374+
375+ return ctx .run (trio .run , _run_trio )
376+
377+ if runtime == "curio" :
378+ if curio is None :
379+ msg = "curio is not installed"
380+ raise RuntimeError (msg )
381+
382+ try :
383+ return curio .AWAIT (coro )
384+ except RuntimeError :
385+ ctx = contextvars .copy_context ()
386+
387+ async def _run_curio () -> _T :
388+ return await coro
389+
390+ return ctx .run (curio .run , _run_curio )
391+
392+ try :
393+ loop = asyncio .get_running_loop ()
394+ except RuntimeError :
395+ loop = None
396+
397+ if loop is not None :
398+ future : asyncio .Future [_T ] = asyncio .run_coroutine_threadsafe (coro , loop ) # type: ignore[arg-type,assignment]
399+ return future .result ()
400+
401+ ctx = contextvars .copy_context ()
402+ return ctx .run (asyncio .run , coro ) # type: ignore[arg-type,return-value]
403+
404+
405+ def _detect_async_runtime_from_coroutine (coro : Coroutine [Any , Any , _T ]) -> str :
406+ """
407+ Detect async runtime by inspecting the coroutine object.
408+
409+ Args:
410+ coro:
411+ The coroutine object to inspect.
412+
413+ Returns:
414+ The detected runtime: "asyncio", "trio", or "curio".
415+ """
416+ if sniffio is not None :
417+ try :
418+ return sniffio .current_async_library ()
419+ except sniffio .AsyncLibraryNotFoundError :
420+ pass
421+
422+ # Inspect the coroutine's code and globals to determine runtime
423+ func_code = coro .cr_code # type: ignore[attr-defined]
424+ code_names = set (func_code .co_names )
425+
426+ # Check if trio or curio modules are actually used in the code
427+ # If the module name is in co_names, it's being accessed (e.g., trio.sleep)
428+ # We also verify with specific indicators to avoid false positives
429+ # Note: If both trio and curio indicators are found, trio takes precedence
430+ if "trio" in code_names :
431+ trio_indicators = {"open_nursery" , "current_trio_token" , "CancelScope" , "sleep" }
432+ if any (indicator in code_names for indicator in trio_indicators ):
433+ return "trio"
434+
435+ if "curio" in code_names :
436+ curio_indicators = {"TaskGroup" , "spawn" , "AWAIT" , "sleep" }
437+ if any (indicator in code_names for indicator in curio_indicators ):
438+ return "curio"
439+
440+ # Default to asyncio as it's the most common
441+ return "asyncio"
0 commit comments