99
1010from __future__ import annotations
1111
12+ import asyncio
13+ import contextvars
14+ import dis
1215import inspect
1316import logging
1417import socket
1518import warnings
1619from contextlib import closing
1720from functools import partial
1821from inspect import Parameter , _ParameterKind
19- from typing import TYPE_CHECKING , TypeVar
22+ from typing import TYPE_CHECKING , Any , TypeVar
2023
2124if TYPE_CHECKING :
22- from collections .abc import Callable , Mapping
25+ from collections .abc import Callable , Coroutine , Mapping
26+
27+ try :
28+ import sniffio # type: ignore[import-not-found]
29+ except ImportError :
30+ sniffio = None # type: ignore[assignment]
31+
32+ try :
33+ import trio # type: ignore[import-not-found]
34+ from trio .lowlevel import current_trio_token # type: ignore[import-not-found]
35+ except ImportError :
36+ trio = None # type: ignore[assignment]
37+ current_trio_token = None # type: ignore[assignment]
38+
39+ try :
40+ import curio # type: ignore[import-not-found,import-untyped]
41+ except (ImportError , AttributeError ):
42+ curio = None # type: ignore[assignment]
2343
2444logger = logging .getLogger (__name__ )
2545
@@ -179,7 +199,7 @@ def find_free_port() -> int:
179199 return s .getsockname ()[1 ]
180200
181201
182- def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901
202+ def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901, PLR0912
183203 """
184204 Apply arguments to a function.
185205
@@ -188,6 +208,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188208 it is possible to pass arguments by name, and falling back to positional
189209 arguments if not.
190210
211+ This function supports both synchronous and asynchronous callables. If the
212+ callable is a coroutine function, it will be executed in an event loop.
213+
191214 Args:
192215 f:
193216 The function to apply the arguments to.
@@ -200,6 +223,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200223 Returns:
201224 The result of the function.
202225 """
226+ # Check if f is a partial wrapping an async function
227+ func_to_check = f .func if isinstance (f , partial ) else f
228+ is_async = inspect .iscoroutinefunction (func_to_check )
203229 signature = inspect .signature (f )
204230 f_name = (
205231 f .__qualname__
@@ -226,7 +252,19 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226252 # First, we inspect the keyword arguments and try and pass in some arguments
227253 # by currying them in.
228254 for param in signature .parameters .values ():
229- if param .name not in args :
255+ # Try matching the parameter name, or if it starts with underscore,
256+ # also try matching without the leading underscore.
257+ arg_key = None
258+ if param .name in args :
259+ arg_key = param .name
260+ elif (
261+ param .name .startswith ("_" )
262+ and len (param .name ) > 1
263+ and param .name [1 :] in args
264+ ):
265+ arg_key = param .name [1 :]
266+
267+ if arg_key is None :
230268 # If a parameter is not known, we will ignore it.
231269 #
232270 # If the ignored parameter doesn't have a default value, it will
@@ -246,12 +284,13 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246284 if param .kind in positional_match :
247285 # We iterate through the parameters in order that they are defined,
248286 # making it fine to pass them in by position one at a time.
249- f = partial (f , args [param .name ])
250- del args [param .name ]
287+ f = partial (f , args [arg_key ])
288+ del args [arg_key ]
289+ continue
251290
252291 if param .kind in keyword_match :
253- f = partial (f , ** {param .name : args [param . name ]})
254- del args [param . name ]
292+ f = partial (f , ** {param .name : args [arg_key ]})
293+ del args [arg_key ]
255294 continue
256295
257296 # At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +321,147 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282321 )
283322
284323 try :
324+ if is_async :
325+ result = f ()
326+ if inspect .iscoroutine (result ):
327+ return _run_async_coroutine (result )
328+ return result
285329 return f ()
286330 except Exception :
287331 logger .exception ("Error occurred while calling function %s" , f_name )
288332 raise
333+
334+
335+ def _run_async_coroutine (coro : Coroutine [Any , Any , _T ]) -> _T : # noqa: C901
336+ """
337+ Run a coroutine in an event loop.
338+
339+ Detects the async runtime (asyncio, trio, or curio) and executes the
340+ coroutine appropriately. Preserves ContextVars when creating a new event
341+ loop, which is important when handlers are called from threads.
342+
343+ Args:
344+ coro:
345+ The coroutine to run.
346+
347+ Returns:
348+ The result of the coroutine.
349+
350+ Raises:
351+ RuntimeError:
352+ If the detected runtime (trio or curio) is not installed.
353+ """
354+ runtime = _detect_async_runtime_from_coroutine (coro )
355+
356+ if runtime == "trio" :
357+ if trio is None :
358+ msg = "trio is not installed"
359+ raise RuntimeError (msg )
360+
361+ if current_trio_token is not None :
362+ try :
363+ token = current_trio_token ()
364+
365+ async def _run_with_token () -> _T :
366+ return await coro
367+
368+ return trio .from_thread .run_sync (_run_with_token , trio_token = token ) # type: ignore[return-value]
369+ except RuntimeError :
370+ pass
371+
372+ ctx = contextvars .copy_context ()
373+
374+ async def _run_trio () -> _T :
375+ return await coro
376+
377+ return ctx .run (trio .run , _run_trio )
378+
379+ if runtime == "curio" :
380+ if curio is None :
381+ msg = "curio is not installed"
382+ raise RuntimeError (msg )
383+
384+ try :
385+ return curio .AWAIT (coro )
386+ except RuntimeError :
387+ ctx = contextvars .copy_context ()
388+
389+ async def _run_curio () -> _T :
390+ return await coro
391+
392+ return ctx .run (curio .run , _run_curio )
393+
394+ try :
395+ loop = asyncio .get_running_loop ()
396+ except RuntimeError :
397+ loop = None
398+
399+ if loop is not None :
400+ future : asyncio .Future [_T ] = asyncio .run_coroutine_threadsafe (coro , loop ) # type: ignore[arg-type,assignment]
401+ return future .result ()
402+
403+ ctx = contextvars .copy_context ()
404+ return ctx .run (asyncio .run , coro ) # type: ignore[arg-type,return-value]
405+
406+
407+ def _detect_async_runtime_from_coroutine (coro : Coroutine [Any , Any , _T ]) -> str : # noqa: C901
408+ """
409+ Detect async runtime by inspecting the coroutine object.
410+
411+ Args:
412+ coro:
413+ The coroutine object to inspect.
414+
415+ Returns:
416+ The detected runtime: "asyncio", "trio", or "curio".
417+ """
418+ if sniffio is not None :
419+ try :
420+ return sniffio .current_async_library ()
421+ except sniffio .AsyncLibraryNotFoundError :
422+ pass
423+
424+ # Inspect bytecode to check for qualified attribute access (e.g., trio.sleep)
425+ # This is more robust than just checking co_names for module and method separately
426+ func_code = coro .cr_code # type: ignore[attr-defined]
427+
428+ # Parse bytecode to find LOAD_GLOBAL/LOAD_NAME followed by LOAD_ATTR patterns
429+ # This detects qualified accesses like `trio.sleep()` or `curio.spawn()`
430+ bytecode = list (dis .get_instructions (func_code ))
431+
432+ trio_detected = False
433+ curio_detected = False
434+
435+ for i , instr in enumerate (bytecode ):
436+ # Check for module.attribute pattern (LOAD_GLOBAL/LOAD_NAME + LOAD_ATTR)
437+ if instr .opname in ("LOAD_GLOBAL" , "LOAD_NAME" ) and i + 1 < len (bytecode ):
438+ next_instr = bytecode [i + 1 ]
439+ if next_instr .opname == "LOAD_ATTR" :
440+ module_name = instr .argval
441+ attr_name = next_instr .argval
442+
443+ # Check for trio-specific qualified access
444+ if module_name == "trio" :
445+ trio_indicators = {
446+ "sleep" ,
447+ "open_nursery" ,
448+ "CancelScope" ,
449+ "current_trio_token" ,
450+ }
451+ if attr_name in trio_indicators :
452+ trio_detected = True
453+
454+ # Check for curio-specific qualified access
455+ elif module_name == "curio" :
456+ curio_indicators = {"sleep" , "spawn" , "TaskGroup" , "AWAIT" }
457+ if attr_name in curio_indicators :
458+ curio_detected = True
459+
460+ # Trio takes precedence if both are detected
461+ if trio_detected :
462+ return "trio"
463+ if curio_detected :
464+ return "curio"
465+
466+ # Default to asyncio as it's the most common
467+ return "asyncio"
0 commit comments