99
1010from __future__ import annotations
1111
12+ import asyncio
13+ import contextvars
1214import inspect
1315import logging
1416import socket
1820from inspect import Parameter , _ParameterKind
1921from typing import TYPE_CHECKING , TypeVar
2022
23+ try :
24+ import sniffio
25+ except ImportError :
26+ sniffio = None # type: ignore[assignment]
27+
28+ try :
29+ import trio
30+ from trio .lowlevel import current_trio_token
31+ except ImportError :
32+ trio = None # type: ignore[assignment]
33+ current_trio_token = None # type: ignore[assignment]
34+
35+ try :
36+ import curio
37+ except (ImportError , AttributeError ):
38+ curio = None # type: ignore[assignment]
39+
2140if TYPE_CHECKING :
2241 from collections .abc import Callable , Mapping
2342
@@ -179,7 +198,7 @@ def find_free_port() -> int:
179198 return s .getsockname ()[1 ]
180199
181200
182- def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901
201+ def apply_args (f : Callable [..., _T ], args : Mapping [str , object ]) -> _T : # noqa: C901, PLR0912
183202 """
184203 Apply arguments to a function.
185204
@@ -188,6 +207,9 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
188207 it is possible to pass arguments by name, and falling back to positional
189208 arguments if not.
190209
210+ This function supports both synchronous and asynchronous callables. If the
211+ callable is a coroutine function, it will be executed in an event loop.
212+
191213 Args:
192214 f:
193215 The function to apply the arguments to.
@@ -200,6 +222,7 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
200222 Returns:
201223 The result of the function.
202224 """
225+ is_async = inspect .iscoroutinefunction (f )
203226 signature = inspect .signature (f )
204227 f_name = (
205228 f .__qualname__
@@ -226,7 +249,15 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
226249 # First, we inspect the keyword arguments and try and pass in some arguments
227250 # by currying them in.
228251 for param in signature .parameters .values ():
229- if param .name not in args :
252+ # Try matching the parameter name, or if it starts with underscore,
253+ # also try matching without the leading underscore.
254+ arg_key = None
255+ if param .name in args :
256+ arg_key = param .name
257+ elif param .name .startswith ("_" ) and param .name [1 :] in args :
258+ arg_key = param .name [1 :]
259+
260+ if arg_key is None :
230261 # If a parameter is not known, we will ignore it.
231262 #
232263 # If the ignored parameter doesn't have a default value, it will
@@ -246,12 +277,12 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
246277 if param .kind in positional_match :
247278 # We iterate through the parameters in order that they are defined,
248279 # making it fine to pass them in by position one at a time.
249- f = partial (f , args [param . name ])
250- del args [param . name ]
280+ f = partial (f , args [arg_key ])
281+ del args [arg_key ]
251282
252283 if param .kind in keyword_match :
253- f = partial (f , ** {param .name : args [param . name ]})
254- del args [param . name ]
284+ f = partial (f , ** {param .name : args [arg_key ]})
285+ del args [arg_key ]
255286 continue
256287
257288 # At this stage, we have checked all arguments. If we have any arguments
@@ -282,7 +313,107 @@ def apply_args(f: Callable[..., _T], args: Mapping[str, object]) -> _T: # noqa:
282313 )
283314
284315 try :
316+ if is_async :
317+ return _run_async (f )
285318 return f ()
286319 except Exception :
287320 logger .exception ("Error occurred while calling function %s" , f_name )
288321 raise
322+
323+
324+ def _run_async (coro_func : Callable [..., _T ]) -> _T :
325+ """
326+ Run an async function in an event loop.
327+
328+ Detects the async runtime (asyncio, trio, or curio) and executes the
329+ coroutine appropriately. Preserves ContextVars when creating a new event
330+ loop, which is important when handlers are called from threads.
331+
332+ Args:
333+ coro_func:
334+ The async function to run (already partially applied).
335+
336+ Returns:
337+ The result of the async function.
338+
339+ Raises:
340+ RuntimeError:
341+ If the detected runtime is not installed.
342+ """
343+ runtime = _detect_async_runtime_from_coro (coro_func )
344+
345+ if runtime == "trio" :
346+ if trio is None :
347+ msg = "trio is not installed"
348+ raise RuntimeError (msg )
349+
350+ if current_trio_token is not None :
351+ try :
352+ token = current_trio_token ()
353+ return trio .from_thread .run_sync (coro_func , trio_token = token )
354+ except RuntimeError :
355+ pass
356+
357+ ctx = contextvars .copy_context ()
358+ return ctx .run (trio .run , coro_func )
359+
360+ if runtime == "curio" :
361+ if curio is None :
362+ msg = "curio is not installed"
363+ raise RuntimeError (msg )
364+
365+ try :
366+ return curio .AWAIT (coro_func ())
367+ except RuntimeError :
368+ ctx = contextvars .copy_context ()
369+ return ctx .run (curio .run , coro_func )
370+
371+ try :
372+ loop = asyncio .get_running_loop ()
373+ except RuntimeError :
374+ loop = None
375+
376+ if loop is not None :
377+ future = asyncio .run_coroutine_threadsafe (coro_func (), loop )
378+ return future .result ()
379+
380+ ctx = contextvars .copy_context ()
381+ return ctx .run (asyncio .run , coro_func ())
382+
383+
384+ def _detect_async_runtime_from_coro (coro_func : Callable [..., _T ]) -> str :
385+ """
386+ Detect async runtime by inspecting the coroutine function.
387+
388+ Args:
389+ coro_func:
390+ The async function to inspect. May be a partial object.
391+
392+ Returns:
393+ The detected runtime: "asyncio", "trio", or "curio".
394+ """
395+ if sniffio is not None :
396+ try :
397+ return sniffio .current_async_library ()
398+ except sniffio .AsyncLibraryNotFoundError :
399+ pass
400+
401+ func = getattr (coro_func , "func" , coro_func )
402+ func_code = getattr (func , "__code__" , None )
403+ func_globals = getattr (func , "__globals__" , {})
404+
405+ if func_code is not None :
406+ code_names = set (func_code .co_names )
407+
408+ trio_indicators = {"trio" , "open_nursery" , "current_trio_token" }
409+ curio_indicators = {"curio" , "TaskGroup" , "AWAIT" }
410+
411+ uses_trio = bool (trio_indicators & code_names )
412+ uses_curio = bool (curio_indicators & code_names )
413+
414+ if uses_trio and "trio" in func_globals :
415+ return "trio"
416+ if uses_curio and "curio" in func_globals :
417+ return "curio"
418+
419+ return "asyncio"
0 commit comments