diff --git a/pyproject.toml b/pyproject.toml index 8c50d8d..b2f36f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "amrita_core" -version = "0.8.5" +version = "0.8.6" description = "High performance, lightweight agent framework." readme = "README.md" requires-python = ">=3.10,<3.15" diff --git a/src/amrita_core/hook/matcher.py b/src/amrita_core/hook/matcher.py index 9570fb2..0b93d29 100644 --- a/src/amrita_core/hook/matcher.py +++ b/src/amrita_core/hook/matcher.py @@ -364,19 +364,15 @@ async def _do_runtime_resolve( async def _simple_run( cls, matcher_list: list[FunctionData], - event: BaseEvent, /, exception_ignored: tuple[type[BaseException], ...], - extra_args: tuple, + extra_args: Iterable[Any], extra_kwargs: dict[str, Any], - config: AmritaConfig | None = None, ) -> bool: """Run a round of matcher Args: matcher_list (list[FunctionData]): Matchers to run - event (BaseEvent): event - config (AmritaConfig, optional): Config exception_ignored (tuple[type[BaseException], ...]): Exceptions to ignore(to raise again) extra_args (tuple): extra args for dependency injection extra_kwargs (dict[str, Any]): extra kwargs for dependency injection @@ -396,36 +392,18 @@ async def _simple_run( line_number: int = frame.f_lineno file_name: str = frame.f_code.co_filename handler = func.function - session_args = [matcher, event, *extra_args] + ( - [config] if config else [] - ) - session_kwargs: dict[str, Any] = deepcopy(extra_kwargs) - runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory - k: v - for k, v in enumerate(session_args) - if isinstance(v, DependsFactory) - } - runtime_kwargs = { - k: v - for k, v in session_kwargs.items() - if isinstance(v, DependsFactory) - } - # These args/kwargs will be generated by Depends - if runtime_args or runtime_kwargs: - if not await cls._do_runtime_resolve( - runtime_args=runtime_args, - runtime_kwargs=runtime_kwargs, - args2update=session_args, - kwargs2update=session_kwargs, - session_args=session_args, - session_kwargs=session_kwargs, - exception_ignored=exception_ignored, - ): - raise RuntimeError("Runtime arguments cannot be resolved") - + if any(isinstance(i, DependsFactory) for i in extra_args): + raise ValueError( + "Runtime dependency injection is not supported in simple_run, please resolve them first or pass it to the trigger_event method" + ) + elif any(isinstance(i, DependsFactory) for i in extra_kwargs.values()): + raise ValueError( + "Runtime dependency injection is not supported in simple_run, please resolve them first or pass it to the trigger_event method" + ) + session_args = [matcher, *extra_args] success, new_args, f_kwargs, d_kw = ( MatcherFactory._resolve_dependencies( - signature, session_args, session_kwargs + signature, session_args, extra_kwargs ) ) if not success: @@ -449,7 +427,7 @@ async def _simple_run( args2update=[], kwargs2update=f_kwargs, session_args=session_args, - session_kwargs=session_kwargs, + session_kwargs=extra_kwargs, exception_ignored=exception_ignored, ): continue @@ -564,15 +542,35 @@ async def trigger_event( debug_log(f"Running matchers for event: {event_type}!") # Check if there are handlers for this event type if priorities: + s_args = [event, *args] + ([config] if config else []) + session_kwargs: dict[str, Any] = deepcopy(kwargs) + runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory + k: v for k, v in enumerate(s_args) if isinstance(v, DependsFactory) + } + runtime_kwargs = { + k: v + for k, v in session_kwargs.items() + if isinstance(v, DependsFactory) + } + # These args/kwargs will be generated by Depends + if runtime_args or runtime_kwargs: + if not await cls._do_runtime_resolve( + runtime_args=runtime_args, + runtime_kwargs=runtime_kwargs, + args2update=s_args, + kwargs2update=session_kwargs, + session_args=s_args, + session_kwargs=session_kwargs, + exception_ignored=exception_ignored, + ): + raise RuntimeError("Runtime arguments cannot be resolved") for priority in priorities: logger.info(f"Running matchers for priority {priority}......") if not await cls._simple_run( handlers[priority], - event, exception_ignored=exception_ignored, - extra_args=args, + extra_args=s_args, extra_kwargs=session_kwargs, - config=config, ): break else: diff --git a/tests/test_matcher.py b/tests/test_matcher.py index 10d3c31..38b4d5f 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -366,44 +366,15 @@ async def test_handler(event: TestEvent, config: AmritaConfig): event = TestEvent() handlers = EventRegistry().get_handlers("test_event") - result = await MatcherFactory._simple_run( - handlers[matcher.priority], event, (), (), {}, self.config - ) - - assert result is True - assert call_count == 1 - - @pytest.mark.asyncio - async def test_simple_run_with_runtime_deps(self): - """Test _simple_run with runtime dependencies from hook_kwargs.""" - received_deps = {} - - async def dep_func() -> str: - return "runtime_value" - - matcher = Matcher("test_event", block=False) # Set block=False - - @matcher.handle() - async def test_handler( - event: TestEvent, config: AmritaConfig, runtime_dep: str - ): - nonlocal received_deps - received_deps["runtime_dep"] = runtime_dep - - event = TestEvent() - handlers = EventRegistry().get_handlers("test_event") - result = await MatcherFactory._simple_run( handlers[matcher.priority], - event, - (), - (), - {"runtime_dep": Depends(dep_func)}, - self.config, + exception_ignored=(), + extra_args=(event, self.config), + extra_kwargs={}, ) assert result is True - assert received_deps["runtime_dep"] == "runtime_value" + assert call_count == 1 @pytest.mark.asyncio async def test_simple_run_with_default_deps(self): @@ -429,42 +400,14 @@ async def test_handler( result = await MatcherFactory._simple_run( handlers[matcher.priority], - event, - (), - (), - {}, - self.config, + exception_ignored=(), + extra_args=(event, self.config), + extra_kwargs={}, ) assert result is True assert received_deps["default_dep_param"] == "default_value" - @pytest.mark.asyncio - async def test_simple_run_runtime_deps_failure(self): - """Test _simple_run when runtime deps fail to resolve.""" - - async def failing_dep() -> str | None: - return None - - matcher = Matcher("test_event", block=False) # Set block=False - - @matcher.handle() - async def test_handler(event: TestEvent, config: AmritaConfig, bad_dep: str): - pass - - event = TestEvent() - handlers = EventRegistry().get_handlers("test_event") - - with pytest.raises(RuntimeError, match="Runtime arguments cannot be resolved"): - await MatcherFactory._simple_run( - handlers[matcher.priority], - event, - (), - (), - {"bad_dep": Depends(failing_dep)}, - self.config, - ) - @pytest.mark.asyncio async def test_simple_run_default_deps_failure(self): """Test _simple_run when default deps fail to resolve (should skip handler).""" @@ -487,11 +430,9 @@ async def test_handler( result = await MatcherFactory._simple_run( handlers[matcher.priority], - event, - (), - (), - {}, - self.config, + exception_ignored=(), + extra_args=(event, self.config), + extra_kwargs={}, ) assert result is True # Should continue to next handler diff --git a/uv.lock b/uv.lock index 27cab1e..d60a38a 100644 --- a/uv.lock +++ b/uv.lock @@ -181,7 +181,7 @@ wheels = [ [[package]] name = "amrita-core" -version = "0.8.5" +version = "0.8.6" source = { editable = "." } dependencies = [ { name = "aiofiles" },