Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "amrita_core"
version = "0.8.5"
version = "0.8.6"
description = "High performance, lightweight agent framework."
readme = "README.md"
requires-python = ">=3.10,<3.15"
Expand Down
72 changes: 35 additions & 37 deletions src/amrita_core/hook/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,19 +364,15 @@ async def _do_runtime_resolve(
async def _simple_run(
cls,
matcher_list: list[FunctionData],
event: BaseEvent,
/,
exception_ignored: tuple[type[BaseException], ...],
extra_args: tuple,
extra_args: Iterable[Any],
extra_kwargs: dict[str, Any],
config: AmritaConfig | None = None,
) -> bool:
"""Run a round of matcher

Args:
matcher_list (list[FunctionData]): Matchers to run
event (BaseEvent): event
config (AmritaConfig, optional): Config
exception_ignored (tuple[type[BaseException], ...]): Exceptions to ignore(to raise again)
extra_args (tuple): extra args for dependency injection
extra_kwargs (dict[str, Any]): extra kwargs for dependency injection
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Expand All @@ -396,36 +392,18 @@ async def _simple_run(
line_number: int = frame.f_lineno
file_name: str = frame.f_code.co_filename
handler = func.function
session_args = [matcher, event, *extra_args] + (
[config] if config else []
)
session_kwargs: dict[str, Any] = deepcopy(extra_kwargs)
runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory
k: v
for k, v in enumerate(session_args)
if isinstance(v, DependsFactory)
}
runtime_kwargs = {
k: v
for k, v in session_kwargs.items()
if isinstance(v, DependsFactory)
}
# These args/kwargs will be generated by Depends
if runtime_args or runtime_kwargs:
if not await cls._do_runtime_resolve(
runtime_args=runtime_args,
runtime_kwargs=runtime_kwargs,
args2update=session_args,
kwargs2update=session_kwargs,
session_args=session_args,
session_kwargs=session_kwargs,
exception_ignored=exception_ignored,
):
raise RuntimeError("Runtime arguments cannot be resolved")

if any(isinstance(i, DependsFactory) for i in extra_args):
raise ValueError(
"Runtime dependency injection is not supported in simple_run, please resolve them first or pass it to the trigger_event method"
)
elif any(isinstance(i, DependsFactory) for i in extra_kwargs.values()):
raise ValueError(
"Runtime dependency injection is not supported in simple_run, please resolve them first or pass it to the trigger_event method"
)
session_args = [matcher, *extra_args]
success, new_args, f_kwargs, d_kw = (
MatcherFactory._resolve_dependencies(
signature, session_args, session_kwargs
signature, session_args, extra_kwargs
)
)
if not success:
Expand All @@ -449,7 +427,7 @@ async def _simple_run(
args2update=[],
kwargs2update=f_kwargs,
session_args=session_args,
session_kwargs=session_kwargs,
session_kwargs=extra_kwargs,
exception_ignored=exception_ignored,
):
continue
Expand Down Expand Up @@ -564,15 +542,35 @@ async def trigger_event(
debug_log(f"Running matchers for event: {event_type}!")
# Check if there are handlers for this event type
if priorities:
s_args = [event, *args] + ([config] if config else [])
session_kwargs: dict[str, Any] = deepcopy(kwargs)
runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory
k: v for k, v in enumerate(s_args) if isinstance(v, DependsFactory)
}
runtime_kwargs = {
k: v
for k, v in session_kwargs.items()
if isinstance(v, DependsFactory)
}
# These args/kwargs will be generated by Depends
if runtime_args or runtime_kwargs:
if not await cls._do_runtime_resolve(
runtime_args=runtime_args,
runtime_kwargs=runtime_kwargs,
args2update=s_args,
kwargs2update=session_kwargs,
session_args=s_args,
session_kwargs=session_kwargs,
exception_ignored=exception_ignored,
):
raise RuntimeError("Runtime arguments cannot be resolved")
for priority in priorities:
logger.info(f"Running matchers for priority {priority}......")
if not await cls._simple_run(
handlers[priority],
event,
exception_ignored=exception_ignored,
extra_args=args,
extra_args=s_args,
extra_kwargs=session_kwargs,
config=config,
):
break
else:
Expand Down
79 changes: 10 additions & 69 deletions tests/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,44 +366,15 @@ async def test_handler(event: TestEvent, config: AmritaConfig):
event = TestEvent()
handlers = EventRegistry().get_handlers("test_event")

result = await MatcherFactory._simple_run(
handlers[matcher.priority], event, (), (), {}, self.config
)

assert result is True
assert call_count == 1

@pytest.mark.asyncio
async def test_simple_run_with_runtime_deps(self):
"""Test _simple_run with runtime dependencies from hook_kwargs."""
received_deps = {}

async def dep_func() -> str:
return "runtime_value"

matcher = Matcher("test_event", block=False) # Set block=False

@matcher.handle()
async def test_handler(
event: TestEvent, config: AmritaConfig, runtime_dep: str
):
nonlocal received_deps
received_deps["runtime_dep"] = runtime_dep

event = TestEvent()
handlers = EventRegistry().get_handlers("test_event")

result = await MatcherFactory._simple_run(
handlers[matcher.priority],
event,
(),
(),
{"runtime_dep": Depends(dep_func)},
self.config,
exception_ignored=(),
extra_args=(event, self.config),
extra_kwargs={},
)

assert result is True
assert received_deps["runtime_dep"] == "runtime_value"
assert call_count == 1

@pytest.mark.asyncio
async def test_simple_run_with_default_deps(self):
Expand All @@ -429,42 +400,14 @@ async def test_handler(

result = await MatcherFactory._simple_run(
handlers[matcher.priority],
event,
(),
(),
{},
self.config,
exception_ignored=(),
extra_args=(event, self.config),
extra_kwargs={},
)

assert result is True
assert received_deps["default_dep_param"] == "default_value"

@pytest.mark.asyncio
async def test_simple_run_runtime_deps_failure(self):
"""Test _simple_run when runtime deps fail to resolve."""

async def failing_dep() -> str | None:
return None

matcher = Matcher("test_event", block=False) # Set block=False

@matcher.handle()
async def test_handler(event: TestEvent, config: AmritaConfig, bad_dep: str):
pass

event = TestEvent()
handlers = EventRegistry().get_handlers("test_event")

with pytest.raises(RuntimeError, match="Runtime arguments cannot be resolved"):
await MatcherFactory._simple_run(
handlers[matcher.priority],
event,
(),
(),
{"bad_dep": Depends(failing_dep)},
self.config,
)

@pytest.mark.asyncio
async def test_simple_run_default_deps_failure(self):
"""Test _simple_run when default deps fail to resolve (should skip handler)."""
Expand All @@ -487,11 +430,9 @@ async def test_handler(

result = await MatcherFactory._simple_run(
handlers[matcher.priority],
event,
(),
(),
{},
self.config,
exception_ignored=(),
extra_args=(event, self.config),
extra_kwargs={},
)

assert result is True # Should continue to next handler
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading