Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@ jobs:
build:
runs-on: ubuntu-latest
timeout-minutes: 30
strategy:
matrix:
python-version: [ "3.10", "3.11", "3.12", "3.13" ]
fail-fast: false
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5

- name: Set up Python
run: uv python install
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}

- name: Set up environment
run: |
Expand All @@ -35,16 +39,17 @@ jobs:
run: |
PYTHON_BIN="$(uv run python -c 'import sys; print(sys.executable)')"
echo "PYTHON_BIN=$PYTHON_BIN" >> $GITHUB_ENV

- name: Run Pyright
uses: jakebailey/pyright-action@v2
with:
python-path: ${{ env.PYTHON_BIN }}
pylance-version: latest-release

- name: Run Unit Tests with JUnit XML output
run: uv run pytest tests/ --cov=src/amrita_core --cov-report=term-missing --cov-report=xml --junitxml=test-results.xml -v

run: uv run pytest tests/ --cov=src/amrita_core --cov-report=term-missing
--cov-report=xml --junitxml=test-results.xml -v

- name: Publish Test Report
uses: dorny/test-reporter@v1
if: success() || failure()
Expand All @@ -58,6 +63,18 @@ jobs:
uses: astral-sh/ruff-action@v3
with:
args: check . --exit-non-zero-on-fix

- name: Build package
run: uv build
run: uv build
all-matrix-jobs-passed:
runs-on: ubuntu-latest
needs: build
if: always()
steps:
- name: Verify all matrix jobs succeeded
run: |
if [ "${{ needs.build.result }}" != "success" ]; then
echo "❌ Some matrix jobs of 'build' failed or were cancelled."
exit 1
fi
echo "✅ All matrix jobs passed!"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "amrita_core"
version = "0.8.4.1"
version = "0.8.5"
description = "High performance, lightweight agent framework."
readme = "README.md"
requires-python = ">=3.10,<3.15"
Expand Down
2 changes: 2 additions & 0 deletions src/amrita_core/builtins/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ async def call_api(
completion_tokens=last_msg.usage.output_tokens,
total_tokens=last_msg.usage.input_tokens + last_msg.usage.output_tokens,
)
text_content = text_resp.getvalue()
yield text_content

yield UniResponse(
content=text_resp.getvalue(),
Expand Down
5 changes: 3 additions & 2 deletions src/amrita_core/builtins/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ async def _execute_tool_loop(
return False

result_msg_list: list[ToolResult] = []
ret: bool = True
for tool_call in tool_calls:
function_name = tool_call.function.name
function_args: dict[str, Any] = json.loads(tool_call.function.arguments)
Expand Down Expand Up @@ -430,12 +431,12 @@ async def _execute_tool_loop(
},
)
)
return False
ret = False

# Send tool call info to user
await self._notify_tool_calls(result_msg_list, function_name, tool_call.id)

return True
return ret

async def _handle_error_append(
self,
Expand Down
20 changes: 11 additions & 9 deletions src/amrita_core/chatmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@
if TYPE_CHECKING:
from .sessions import SessionData

# Global lock for thread-safe operations in the chat manager
LOCK = aiologic.Lock()


RESPONSE_CALLBACK_TYPE = Callable[[RESPONSE_TYPE], Awaitable[Any]] | None

# Type vars
Expand Down Expand Up @@ -432,6 +428,7 @@ class ChatObject(SuspendObjectStream[RESPONSE_TYPE]):
_err: BaseException | None = None
_hook_kwargs: dict[str, Any]
_hook_args: tuple[Any, ...]
_chatman: ChatManager

_raised_exc: tuple[type[BaseException], ...]

Expand All @@ -446,6 +443,7 @@ def __init__(
preset: ModelPreset | None = None,
auto_create_session: bool = False,
*,
chat_man: ChatManager | None = None,
train_template: Template = DEFAULT_TEMPLATE,
jinja2_vars: dict[str, Any] | None = None,
agent_strategy: type[AgentStrategy] = ReActAgentStrategy,
Expand All @@ -467,13 +465,15 @@ def __init__(
preset (ModelPreset | None, optional): Preset used for this call. Defaults to None.
auto_create_session (bool, optional): Whether to automatically create a session if it does not exist. Defaults to False.
jinja2_vars (dict[str, Any] | None, optional): Variables to be passed to the template system. Defaults to None.
chat_man (ChatManager | None, optional): ChatManager that ChatObject will be bound to. Defaults to None(Global ChatManager).
train_template (Template, optional): Jinja2 template used to format system message.
agent_strategy (type[AgentStrategy], optional): Agent strategy to be used for execution. Defaults to ReActAgentStrategy.
hook_args (tuple[Any, ...], optional): Arguments could be passed to the Matcher function. Defaults to ().
hook_kwargs (dict[str, Any] | None, optional): Keyword arguments could be passed to the Matcher function. Defaults to None.
exception_ignored (tuple[type[BaseException], ...], optional): These exceptions will be raised again if they are raised in the Matcher function. Defaults to ().
queue_size (int, optional): Maximum number of message chunks to be stored in the queue. Defaults to 45.
"""
global chat_manager
sm = SessionsManager()
if auto_create_session and not sm.is_session_registered(session_id):
sm.init_session(session_id)
Expand Down Expand Up @@ -501,6 +501,7 @@ def __init__(
self.extra_usage = UniResponseUsage(
prompt_tokens=0, completion_tokens=0, total_tokens=0
)
self._chatman = chat_man or chat_manager
# other
self.last_call = datetime.now(utc)
self.preset = preset or (
Expand Down Expand Up @@ -630,15 +631,15 @@ async def _entry(self) -> None:

try:
self._is_running = True
await chat_manager.add_chat_object(self)
await self._chatman.add_chat_object(self)

await self._run()
finally:
self._is_running = False
self._is_done = True
self.end_at = datetime.now(utc)
chat_manager.running_chat_object_id2map.pop(self.stream_id, None)
if chat_manager.clean_obj(
self._chatman.running_chat_object_id2map.pop(self.stream_id, None)
if self._chatman.clean_obj(
self.session_id, 10000
): # A hard limit just to avoid memory leaks
logger.warning(
Expand Down Expand Up @@ -899,6 +900,7 @@ class ChatManager:
default_factory=lambda: defaultdict(list)
)
running_chat_object_id2map: dict[str, ChatObjectMeta] = field(default_factory=dict)
_lock: aiologic.Lock = field(default_factory=aiologic.Lock)

def clean_obj(self, k: str, maxitems: int = 10) -> bool:
"""
Expand Down Expand Up @@ -948,7 +950,7 @@ async def clean_chat_objects(self, maxitems: int = 10) -> None:
"""
Asynchronously clean up all running chat objects, limiting the number of objects for each key to no more than 10
"""
async with LOCK:
async with self._lock:
for key in self.running_chat_object.keys():
self.clean_obj(key, maxitems)

Expand All @@ -959,7 +961,7 @@ async def add_chat_object(self, chat_object: ChatObject) -> None:
Args:
chat_object (ChatObject): Chat object instance
"""
async with LOCK:
async with self._lock:
meta: ChatObjectMeta = chat_object.get_snapshot()
self.running_chat_object_id2map[chat_object.stream_id] = meta
key = chat_object.session_id
Expand Down
Loading
Loading