Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "amrita_core"
version = "0.8.3.1"
version = "0.8.4"
description = "High performance, lightweight agent framework."
readme = "README.md"
requires-python = ">=3.10,<3.15"
Expand Down
46 changes: 27 additions & 19 deletions src/amrita_core/builtins/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ async def _generate_reasoning_msg(
then: Callable[
[
Self,
ToolCall, # trigger_response
UniResponse[str, None], # tool_response
ToolCall,
UniResponse[str, None],
],
Awaitable[Any],
],
Expand Down Expand Up @@ -303,21 +303,16 @@ async def _build_stop_response_and_append(
pass

@abstractmethod
async def _append_tool_result_to_context(
self,
tool_call: ToolCall,
func_response: str,
response_msg: UniResponse[None, list[ToolCall] | None],
async def _append_reasoning(
self, tool_call: ToolCall, reasoning_content: UniResponse[str, None]
):
"""Append tool result to context (strategy-specific).
"""Append reasoning content to context (strategy-specific).

Subclasses must implement this to define how tool results are added to context.
Subclasses should use self.ctx.message to access the message list.
Subclasses must implement this to define how reasoning results are added to context.

Args:
tool_call: The tool call object
func_response: The function execution result
response_msg: The original response message
tool_call: The tool call object containing the reasoning request
reasoning_content: The response containing the generated reasoning content
"""
...

Expand Down Expand Up @@ -453,15 +448,21 @@ async def _handle_error_append(
...

@abstractmethod
async def _append_reasoning(
self, tool_call: ToolCall, reasoning_content: UniResponse[str, None]
async def _append_tool_result_to_context(
self,
tool_call: ToolCall,
func_response: str,
response_msg: UniResponse[None, list[ToolCall] | None],
):
"""Append reasoning content to context (strategy-specific).
"""Append tool result to context (strategy-specific).

Subclasses must implement this to define how reasoning results are added to context.
Subclasses must implement this to define how tool results are added to context.
Subclasses should use self.ctx.message to access the message list.

Args:
response: The response from tools_caller containing reasoning tool calls
tool_call: The tool call object
func_response: The function execution result
response_msg: The original response message
"""
...

Expand Down Expand Up @@ -942,4 +943,11 @@ def get_category(cls) -> Literal["agent-mixed"]:

AmritaAgentStrategy = ReActAgentStrategy # Alias for backward compatibility

__all__ = ["PROCESS_MESSAGE"] # backward compatibility
__all__ = [
"PROCESS_MESSAGE",
"AmritaAgentStrategy",
"BaseReActAgentStrategy",
"HybridReActAgentStrategy",
"NoActionAgentStrategy",
"ReActAgentStrategy",
] # backward compatibility
4 changes: 0 additions & 4 deletions src/amrita_core/hook/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@ class MatcherException(Exception):
"""Base exception for Matcher."""


class BlockException(MatcherException):
pass


class CancelException(MatcherException):
pass

Expand Down
93 changes: 38 additions & 55 deletions src/amrita_core/hook/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from .event import BaseEvent
from .exception import (
BlockException,
CancelException,
MatcherException,
PassException,
Expand Down Expand Up @@ -327,17 +326,18 @@ async def _simple_run(
cls,
matcher_list: list[FunctionData],
event: BaseEvent,
config: AmritaConfig,
/,
exception_ignored: tuple[type[BaseException], ...],
extra_args: tuple,
extra_kwargs: dict[str, Any],
config: AmritaConfig | None = None,
) -> bool:
"""Run a round of matcher

Args:
matcher_list (list[FunctionData]): Matchers to run
event (BaseEvent): event
config (AmritaConfig): Config
config (AmritaConfig, optional): Config
exception_ignored (tuple[type[BaseException], ...]): Exceptions to ignore(to raise again)
extra_args (tuple): extra args for dependency injection
extra_kwargs (dict[str, Any]): extra kwargs for dependency injection
Expand All @@ -346,12 +346,14 @@ async def _simple_run(
bool: Should continue to run.
"""
for func in matcher_list:
signature = func.signature
frame = func.frame
line_number = frame.f_lineno
file_name = frame.f_code.co_filename
signature: inspect.Signature = func.signature
frame: FrameType = func.frame
line_number: int = frame.f_lineno
file_name: str = frame.f_code.co_filename
handler = func.function
session_args = [func.matcher, event, config, *extra_args]
session_args = [func.matcher, event, *extra_args] + (
[config] if config else []
)
session_kwargs: dict[str, Any] = deepcopy(extra_kwargs)
runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory
k: v
Expand All @@ -364,13 +366,13 @@ async def _simple_run(
# These args/kwargs will be generated by Depends
if runtime_args or runtime_kwargs:
if not await cls._do_runtime_resolve(
runtime_args,
runtime_kwargs,
session_args,
session_kwargs,
session_args,
session_kwargs,
exception_ignored,
runtime_args=runtime_args,
runtime_kwargs=runtime_kwargs,
args2update=session_args,
kwargs2update=session_kwargs,
session_args=session_args,
session_kwargs=session_kwargs,
exception_ignored=exception_ignored,
):
raise RuntimeError("Runtime arguments cannot be resolved")

Expand All @@ -393,13 +395,13 @@ async def _simple_run(
continue
# Do kwargs dependency injection
if d_kw and not await cls._do_runtime_resolve(
{},
d_kw,
[],
f_kwargs,
session_args,
session_kwargs,
exception_ignored,
runtime_args={},
runtime_kwargs=d_kw,
args2update=[],
kwargs2update=f_kwargs,
session_args=session_args,
session_kwargs=session_kwargs,
exception_ignored=exception_ignored,
):
continue

Expand All @@ -414,7 +416,7 @@ async def _simple_run(
)
continue
except Exception as e:
if isinstance(e, CancelException | BlockException):
if isinstance(e, CancelException):
logger.info("Cancelled Matcher processing")
return False
elif isinstance(e, ChatException):
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Expand All @@ -437,16 +439,8 @@ async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
/,
exception_ignored: tuple[type[Exception], ...] = (),
) -> None: ...

@overload
@classmethod
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
exception_ignored: tuple[type[BaseException], ...] = (),
**kwargs: Any,
) -> None: ...

Expand All @@ -455,44 +449,36 @@ async def trigger_event(
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
) -> None: ...
@overload
@classmethod
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
exception_ignored: tuple[type[Exception], ...] = (),
exception_ignored: tuple[type[BaseException], ...] = (),
**kwargs: Any,
) -> None: ...
@overload
@classmethod
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
exception_ignored: tuple[type[Exception], ...] = (),
config: None = None,
exception_ignored: tuple[type[BaseException], ...] = (),
**kwargs: Any,
) -> None: ...

@overload
@classmethod
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
**kwargs: Any,
) -> None: ...
@classmethod
async def trigger_event(
cls,
event: BaseEvent,
config: AmritaConfig,
*args: Any,
exception_ignored: tuple[type[Exception], ...] = (),
config: AmritaConfig | None = None,
exception_ignored: tuple[type[BaseException], ...] = (),
**kwargs,
) -> None:
"""Trigger a specific type of event and call all registered event handlers for that type.
Expand All @@ -514,9 +500,6 @@ async def trigger_event(
config = i
if not event:
raise RuntimeError("No event found in args")
elif not config:
raise RuntimeError("No config found in args")

session_kwargs = kwargs
event_type: EventTypeEnum | str = event.get_event_type() # Get event type
handlers = EventRegistry().get_handlers(event_type)
Expand All @@ -529,10 +512,10 @@ async def trigger_event(
if not await cls._simple_run(
handlers[priority],
event,
config,
exception_ignored,
args,
session_kwargs,
exception_ignored=exception_ignored,
extra_args=args,
extra_kwargs=session_kwargs,
config=config,
):
break
else:
Expand Down
Loading
Loading