diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d706704..72d1a6e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,6 +16,10 @@ jobs: build: runs-on: ubuntu-latest timeout-minutes: 30 + strategy: + matrix: + python-version: [ "3.10", "3.11", "3.12", "3.13" ] + fail-fast: false steps: - name: Checkout code uses: actions/checkout@v4 @@ -23,8 +27,8 @@ jobs: - name: Install uv uses: astral-sh/setup-uv@v5 - - name: Set up Python - run: uv python install + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} - name: Set up environment run: | @@ -35,16 +39,17 @@ jobs: run: | PYTHON_BIN="$(uv run python -c 'import sys; print(sys.executable)')" echo "PYTHON_BIN=$PYTHON_BIN" >> $GITHUB_ENV - + - name: Run Pyright uses: jakebailey/pyright-action@v2 with: python-path: ${{ env.PYTHON_BIN }} pylance-version: latest-release - + - name: Run Unit Tests with JUnit XML output - run: uv run pytest tests/ --cov=src/amrita_core --cov-report=term-missing --cov-report=xml --junitxml=test-results.xml -v - + run: uv run pytest tests/ --cov=src/amrita_core --cov-report=term-missing + --cov-report=xml --junitxml=test-results.xml -v + - name: Publish Test Report uses: dorny/test-reporter@v1 if: success() || failure() @@ -58,6 +63,18 @@ jobs: uses: astral-sh/ruff-action@v3 with: args: check . --exit-non-zero-on-fix - + - name: Build package - run: uv build \ No newline at end of file + run: uv build + all-matrix-jobs-passed: + runs-on: ubuntu-latest + needs: build + if: always() + steps: + - name: Verify all matrix jobs succeeded + run: | + if [ "${{ needs.build.result }}" != "success" ]; then + echo "❌ Some matrix jobs of 'build' failed or were cancelled." + exit 1 + fi + echo "✅ All matrix jobs passed!" diff --git a/pyproject.toml b/pyproject.toml index 955a4c7..8c50d8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "amrita_core" -version = "0.8.4.1" +version = "0.8.5" description = "High performance, lightweight agent framework." readme = "README.md" requires-python = ">=3.10,<3.15" diff --git a/src/amrita_core/builtins/adapter.py b/src/amrita_core/builtins/adapter.py index 400abbd..5efeea1 100644 --- a/src/amrita_core/builtins/adapter.py +++ b/src/amrita_core/builtins/adapter.py @@ -280,6 +280,8 @@ async def call_api( completion_tokens=last_msg.usage.output_tokens, total_tokens=last_msg.usage.input_tokens + last_msg.usage.output_tokens, ) + text_content = text_resp.getvalue() + yield text_content yield UniResponse( content=text_resp.getvalue(), diff --git a/src/amrita_core/builtins/agent.py b/src/amrita_core/builtins/agent.py index fc9103b..e8d1945 100644 --- a/src/amrita_core/builtins/agent.py +++ b/src/amrita_core/builtins/agent.py @@ -345,6 +345,7 @@ async def _execute_tool_loop( return False result_msg_list: list[ToolResult] = [] + ret: bool = True for tool_call in tool_calls: function_name = tool_call.function.name function_args: dict[str, Any] = json.loads(tool_call.function.arguments) @@ -430,12 +431,12 @@ async def _execute_tool_loop( }, ) ) - return False + ret = False # Send tool call info to user await self._notify_tool_calls(result_msg_list, function_name, tool_call.id) - return True + return ret async def _handle_error_append( self, diff --git a/src/amrita_core/chatmanager.py b/src/amrita_core/chatmanager.py index 4eafefc..b975682 100644 --- a/src/amrita_core/chatmanager.py +++ b/src/amrita_core/chatmanager.py @@ -66,10 +66,6 @@ if TYPE_CHECKING: from .sessions import SessionData -# Global lock for thread-safe operations in the chat manager -LOCK = aiologic.Lock() - - RESPONSE_CALLBACK_TYPE = Callable[[RESPONSE_TYPE], Awaitable[Any]] | None # Type vars @@ -432,6 +428,7 @@ class ChatObject(SuspendObjectStream[RESPONSE_TYPE]): _err: BaseException | None = None _hook_kwargs: dict[str, Any] _hook_args: tuple[Any, ...] + _chatman: ChatManager _raised_exc: tuple[type[BaseException], ...] @@ -446,6 +443,7 @@ def __init__( preset: ModelPreset | None = None, auto_create_session: bool = False, *, + chat_man: ChatManager | None = None, train_template: Template = DEFAULT_TEMPLATE, jinja2_vars: dict[str, Any] | None = None, agent_strategy: type[AgentStrategy] = ReActAgentStrategy, @@ -467,6 +465,7 @@ def __init__( preset (ModelPreset | None, optional): Preset used for this call. Defaults to None. auto_create_session (bool, optional): Whether to automatically create a session if it does not exist. Defaults to False. jinja2_vars (dict[str, Any] | None, optional): Variables to be passed to the template system. Defaults to None. + chat_man (ChatManager | None, optional): ChatManager that ChatObject will be bound to. Defaults to None(Global ChatManager). train_template (Template, optional): Jinja2 template used to format system message. agent_strategy (type[AgentStrategy], optional): Agent strategy to be used for execution. Defaults to ReActAgentStrategy. hook_args (tuple[Any, ...], optional): Arguments could be passed to the Matcher function. Defaults to (). @@ -474,6 +473,7 @@ def __init__( exception_ignored (tuple[type[BaseException], ...], optional): These exceptions will be raised again if they are raised in the Matcher function. Defaults to (). queue_size (int, optional): Maximum number of message chunks to be stored in the queue. Defaults to 45. """ + global chat_manager sm = SessionsManager() if auto_create_session and not sm.is_session_registered(session_id): sm.init_session(session_id) @@ -501,6 +501,7 @@ def __init__( self.extra_usage = UniResponseUsage( prompt_tokens=0, completion_tokens=0, total_tokens=0 ) + self._chatman = chat_man or chat_manager # other self.last_call = datetime.now(utc) self.preset = preset or ( @@ -630,15 +631,15 @@ async def _entry(self) -> None: try: self._is_running = True - await chat_manager.add_chat_object(self) + await self._chatman.add_chat_object(self) await self._run() finally: self._is_running = False self._is_done = True self.end_at = datetime.now(utc) - chat_manager.running_chat_object_id2map.pop(self.stream_id, None) - if chat_manager.clean_obj( + self._chatman.running_chat_object_id2map.pop(self.stream_id, None) + if self._chatman.clean_obj( self.session_id, 10000 ): # A hard limit just to avoid memory leaks logger.warning( @@ -899,6 +900,7 @@ class ChatManager: default_factory=lambda: defaultdict(list) ) running_chat_object_id2map: dict[str, ChatObjectMeta] = field(default_factory=dict) + _lock: aiologic.Lock = field(default_factory=aiologic.Lock) def clean_obj(self, k: str, maxitems: int = 10) -> bool: """ @@ -948,7 +950,7 @@ async def clean_chat_objects(self, maxitems: int = 10) -> None: """ Asynchronously clean up all running chat objects, limiting the number of objects for each key to no more than 10 """ - async with LOCK: + async with self._lock: for key in self.running_chat_object.keys(): self.clean_obj(key, maxitems) @@ -959,7 +961,7 @@ async def add_chat_object(self, chat_object: ChatObject) -> None: Args: chat_object (ChatObject): Chat object instance """ - async with LOCK: + async with self._lock: meta: ChatObjectMeta = chat_object.get_snapshot() self.running_chat_object_id2map[chat_object.stream_id] = meta key = chat_object.session_id diff --git a/src/amrita_core/hook/matcher.py b/src/amrita_core/hook/matcher.py index ae35bf0..9570fb2 100644 --- a/src/amrita_core/hook/matcher.py +++ b/src/amrita_core/hook/matcher.py @@ -1,10 +1,13 @@ from __future__ import annotations import asyncio +import datetime import inspect from collections import defaultdict -from collections.abc import Awaitable, Callable, Iterable +from collections.abc import Awaitable, Callable, Hashable, Iterable from copy import deepcopy +from dataclasses import asdict, dataclass +from dataclasses import field as Field from types import FrameType, MappingProxyType from typing import ( Any, @@ -14,14 +17,16 @@ TypeVar, overload, ) +from uuid import UUID, uuid4 +import aiologic from exceptiongroup import ExceptionGroup -from pydantic import BaseModel, Field from typing_extensions import Never, Self from amrita_core.config import AmritaConfig from amrita_core.hook.event import EventTypeEnum from amrita_core.logging import debug_log, logger +from amrita_core.weakcache import WeakValueLRUCache from .event import BaseEvent from .exception import ( @@ -33,12 +38,16 @@ ChatException: TypeAlias = MatcherException -class FunctionData(BaseModel, arbitrary_types_allowed=True): - function: Callable[..., Awaitable[Any]] = Field(...) - signature: inspect.Signature = Field(...) - frame: FrameType = Field(...) - priority: int = Field(...) - matcher: Matcher = Field(...) +@dataclass +class FunctionData: + function: Callable[..., Awaitable[Any]] = Field() + signature: inspect.Signature = Field() + frame: FrameType = Field() + priority: int = Field() + matcher: Matcher = Field() + + def model_dump(self): + return asdict(self) class EventRegistry: @@ -62,20 +71,35 @@ def get_all(self) -> defaultdict[str, defaultdict[int, list[FunctionData]]]: return self._event_handlers -class Matcher: - def __init__(self, event_type: str, priority: int = 10, block: bool = True): +class Matcher(Hashable): + _dead_at: datetime.datetime | None = None + id: UUID + + def __init__( + self, + event_type: str, + priority: int = 10, + block: bool = True, + dead_at: datetime.datetime | None = None, + ): """Constructor, initialize Matcher object. Args: event_type (str): Event type priority (int, optional): Priority. Defaults to 10. block (bool, optional): Whether to block subsequent events. Defaults to True. + dead_at (datetime.datetime | None, optional): Deadline for this matcher. Defaults to None. """ if priority <= 0: raise ValueError("Event priority cannot be zero or negative!") - self.event_type = event_type - self.priority = priority - self.block = block + self.event_type: str = event_type + self.priority: int = priority + self.block: bool = block + self._dead_at = dead_at + self.id = uuid4() + + def __hash__(self): + return hash(self.id.bytes) def append_handler(self, func: Callable[..., Awaitable[Any]]): frame = inspect.currentframe() @@ -117,6 +141,10 @@ def pass_event(self) -> Never: """ raise PassException() # pragma: no cover + @property + def dead(self) -> bool: + return self._dead_at is not None and self._dead_at < datetime.datetime.now() + T = TypeVar("T") @@ -189,6 +217,17 @@ class MatcherFactory: Event handling factory class. """ + _lock_pool: ClassVar[WeakValueLRUCache[str, aiologic.Lock]] = WeakValueLRUCache( + capacity=1024, loose_mode=True + ) + + @classmethod + def _repo_lock(cls, category: str) -> aiologic.Lock: + if (lock := cls._lock_pool.get(category)) is None: + lock = aiologic.Lock() + cls._lock_pool[category] = lock + return lock + @staticmethod def _resolve_dependencies( signature: inspect.Signature, @@ -345,92 +384,107 @@ async def _simple_run( Returns: bool: Should continue to run. """ - for func in matcher_list: - signature: inspect.Signature = func.signature - frame: FrameType = func.frame - line_number: int = frame.f_lineno - file_name: str = frame.f_code.co_filename - handler = func.function - session_args = [func.matcher, event, *extra_args] + ( - [config] if config else [] - ) - session_kwargs: dict[str, Any] = deepcopy(extra_kwargs) - runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory - k: v - for k, v in enumerate(session_args) - if isinstance(v, DependsFactory) - } - runtime_kwargs = { - k: v for k, v in session_kwargs.items() if isinstance(v, DependsFactory) - } - # These args/kwargs will be generated by Depends - if runtime_args or runtime_kwargs: - if not await cls._do_runtime_resolve( - runtime_args=runtime_args, - runtime_kwargs=runtime_kwargs, - args2update=session_args, - kwargs2update=session_kwargs, + _dead_to_remove: list[FunctionData] = [] + try: + for func in matcher_list: + matcher: Matcher = func.matcher + if matcher.dead: + _dead_to_remove.append(func) + continue + signature: inspect.Signature = func.signature + frame: FrameType = func.frame + line_number: int = frame.f_lineno + file_name: str = frame.f_code.co_filename + handler = func.function + session_args = [matcher, event, *extra_args] + ( + [config] if config else [] + ) + session_kwargs: dict[str, Any] = deepcopy(extra_kwargs) + runtime_args: dict[int, DependsFactory] = { # index -> DependsFactory + k: v + for k, v in enumerate(session_args) + if isinstance(v, DependsFactory) + } + runtime_kwargs = { + k: v + for k, v in session_kwargs.items() + if isinstance(v, DependsFactory) + } + # These args/kwargs will be generated by Depends + if runtime_args or runtime_kwargs: + if not await cls._do_runtime_resolve( + runtime_args=runtime_args, + runtime_kwargs=runtime_kwargs, + args2update=session_args, + kwargs2update=session_kwargs, + session_args=session_args, + session_kwargs=session_kwargs, + exception_ignored=exception_ignored, + ): + raise RuntimeError("Runtime arguments cannot be resolved") + + success, new_args, f_kwargs, d_kw = ( + MatcherFactory._resolve_dependencies( + signature, session_args, session_kwargs + ) + ) + if not success: + failed_args = list( + { + k: v + for k, v in signature.parameters.items() + if v.annotation is inspect._empty + }.keys() + ) + if failed_args: + logger.warning( + f"Matcher {func.function.__name__} (File: {file_name}: Line {frame.f_lineno!s}) has untyped parameters!" + + f"(Args:{''.join(i + ',' for i in failed_args)}).Skipping......" + ) + continue + # Do kwargs dependency injection + if d_kw and not await cls._do_runtime_resolve( + runtime_args={}, + runtime_kwargs=d_kw, + args2update=[], + kwargs2update=f_kwargs, session_args=session_args, session_kwargs=session_kwargs, exception_ignored=exception_ignored, ): - raise RuntimeError("Runtime arguments cannot be resolved") + continue - success, new_args, f_kwargs, d_kw = MatcherFactory._resolve_dependencies( - signature, session_args, session_kwargs - ) - if not success: - failed_args = list( - { - k: v - for k, v in signature.parameters.items() - if v.annotation is inspect._empty - }.keys() - ) - if failed_args: - logger.warning( - f"Matcher {func.function.__name__} (File: {file_name}: Line {frame.f_lineno!s}) has untyped parameters!" - + f"(Args:{''.join(i + ',' for i in failed_args)}).Skipping......" + # Call the handler + try: + logger.info(f"Starting to run Matcher: '{handler.__name__}'") + + await handler(*new_args, **f_kwargs) + except PassException: + logger.info( + f"Matcher '{handler.__name__}'(~{file_name}:{line_number}) was skipped" ) - continue - # Do kwargs dependency injection - if d_kw and not await cls._do_runtime_resolve( - runtime_args={}, - runtime_kwargs=d_kw, - args2update=[], - kwargs2update=f_kwargs, - session_args=session_args, - session_kwargs=session_kwargs, - exception_ignored=exception_ignored, - ): - continue - - # Call the handler - try: - logger.info(f"Starting to run Matcher: '{handler.__name__}'") - - await handler(*new_args, **f_kwargs) - except PassException: - logger.info( - f"Matcher '{handler.__name__}'(~{file_name}:{line_number}) was skipped" - ) - continue - except Exception as e: - if isinstance(e, CancelException): - logger.info("Cancelled Matcher processing") - return False - elif isinstance(e, ChatException): - raise - elif exception_ignored and isinstance(e, exception_ignored): - raise - logger.opt(exception=e, colors=True).error( - f"An error occurred while running '{handler.__name__}'({file_name}:{line_number}) " - ) - continue - finally: - logger.info(f"Handler {handler.__name__} finished") - if func.matcher.block: - return False + continue + except Exception as e: + if isinstance(e, CancelException): + logger.info("Cancelled Matcher processing") + return False + elif isinstance(e, ChatException): + raise + elif exception_ignored and isinstance(e, exception_ignored): + raise + logger.opt(exception=e, colors=True).error( + f"An error occurred while running '{handler.__name__}'({file_name}:{line_number}) " + ) + continue + finally: + logger.info(f"Handler {handler.__name__} finished") + + if matcher.block: + return False + finally: + if _dead_to_remove: + for func in _dead_to_remove: + matcher_list.remove(func) return True @overload @@ -502,26 +556,29 @@ async def trigger_event( raise RuntimeError("No event found in args") session_kwargs = kwargs event_type: EventTypeEnum | str = event.get_event_type() # Get event type - handlers = EventRegistry().get_handlers(event_type) - priorities: list[int] = sorted(handlers.keys(), reverse=False) - debug_log(f"Running matchers for event: {event_type}!") - # Check if there are handlers for this event type - if priorities: - for priority in priorities: - logger.info(f"Running matchers for priority {priority}......") - if not await cls._simple_run( - handlers[priority], - event, - exception_ignored=exception_ignored, - extra_args=args, - extra_kwargs=session_kwargs, - config=config, - ): - break - else: - logger.warning( - f"No registered Matcher for {event_type} event, skipping processing." + async with cls._repo_lock(event_type): + handlers: defaultdict[int, list[FunctionData]] = ( + EventRegistry().get_handlers(event_type) ) + priorities: list[int] = sorted(handlers.keys(), reverse=False) + debug_log(f"Running matchers for event: {event_type}!") + # Check if there are handlers for this event type + if priorities: + for priority in priorities: + logger.info(f"Running matchers for priority {priority}......") + if not await cls._simple_run( + handlers[priority], + event, + exception_ignored=exception_ignored, + extra_args=args, + extra_kwargs=session_kwargs, + config=config, + ): + break + else: + logger.warning( + f"No registered Matcher for {event_type} event, skipping processing." + ) MatcherManager = MatcherFactory diff --git a/src/amrita_core/tools/mcp.py b/src/amrita_core/tools/mcp.py index 2b36777..7aa0e7e 100644 --- a/src/amrita_core/tools/mcp.py +++ b/src/amrita_core/tools/mcp.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import contextlib import json import random @@ -38,23 +39,39 @@ class MCPClient: mcp_client: Client | None = None server_script: MCP_SERVER_SCRIPT_TYPE + _close_waitter: asyncio.Future | None = None + _close_ttl: int def __init__( self, server_script: MCP_SERVER_SCRIPT_TYPE, + connection_ttl: int = 60, # When the connection is not used for a long time, it will be closed # headers: dict | None = None, ): + """Constructor + + Args: + server_script (MCP_SERVER_SCRIPT_TYPE): Server script + connection_ttl (int, optional): TTL for connection when planned to close. Defaults to 60. Set to -1 to disable + """ self.mcp_client = None self.server_script: MCP_SERVER_SCRIPT_TYPE = server_script self.tools: list[MCPToolSchema] = [] self.openai_tools: list[ToolFunctionSchema] = [] + if connection_ttl < -1: + raise ValueError("connection_ttl must be greater than or equals to -1") + self._close_ttl = connection_ttl async def __aenter__(self) -> Self: - await self._connect() + self._close_waitter = None + if not self.mcp_client: + await self._connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - await self._close() + if self._close_ttl == -1: + return await self.close_no_wait() + self.close() async def bound_to(self, tm: ClientManager): try: @@ -94,6 +111,7 @@ async def _connect(self, update_tools: bool = False): Args: update_tools (bool, optional): whether to update the tool list. Defaults to False. """ + await self._clean_waitter() if self.mcp_client is not None: raise RuntimeError("MCP Server is already connected!") @@ -109,6 +127,43 @@ async def _connect(self, update_tools: bool = False): logger.info(f"Available tools: {[tool.name for tool in self.tools]}") self._cast_tool_to_amrita() + async def _clean_waitter(self): + if self._close_waitter is not None: + with contextlib.suppress(asyncio.CancelledError): + self._close_waitter.cancel() + await asyncio.sleep(0) + self._close_waitter = None + + def close(self) -> asyncio.Future[None]: + """Create a TTL-Task to wait for connection's close. + + Returns: + asyncio.Task[None]: Waiting task. + """ + if self._close_waitter is not None and not self._close_waitter.done(): + return self._close_waitter + if self._close_ttl == -1: + raise RuntimeError("TTL is not set. Please use `close_no_wait` instead.") + + async def waitter() -> None: + + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(self._close_ttl) + await self._close() + + self._close_waitter = asyncio.create_task(waitter()) + return self._close_waitter + + async def close_no_wait(self): + await self._clean_waitter() + await self._close() + + async def _close(self) -> None: + """Close connection""" + if self.mcp_client: + await self.mcp_client.__aexit__(None, None, None) + self.mcp_client = None + def _format_tools_for_openai(self): """Convert MCP tool format to OpenAI tool format""" openai_tools: list[ToolFunctionSchema] = [ @@ -142,12 +197,6 @@ def get_original_tools(self) -> list[MCPToolSchema]: """Get original MCP tool list""" return self.tools - async def _close(self): - """Close connection""" - if self.mcp_client: - await self.mcp_client.__aexit__(None, None, None) - self.mcp_client = None - class MultiClientManager(ContextThreadsafe): clients: list[MCPClient] diff --git a/src/amrita_core/weakcache.py b/src/amrita_core/weakcache.py new file mode 100644 index 0000000..263754b --- /dev/null +++ b/src/amrita_core/weakcache.py @@ -0,0 +1,274 @@ +from __future__ import annotations + +import weakref +from collections import OrderedDict +from collections.abc import Generator, Hashable, Iterator +from typing import Any, Generic, TypeVar, overload + +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") +T = TypeVar("T") + + +class WeakValueLRUCache(Generic[K, V]): + """Weak reference LRU cache implementation. + + Always used for locks pool. + """ + + __marker = object() + _capacity: int + _cache: OrderedDict[K, weakref.ReferenceType[V]] + _loose_mode: bool + + def __init__( + self, + *, + capacity: int, + loose_mode: bool = False, + items: dict[K, V] | None = None, + ): + """Constructor of WeakValueLRUCache + + Args: + capacity (int): Size of cache. + loose_mode (bool, optional): When the length of items is out of capacity, still allowed to add item to cache. Defaults to False. + items (dict[K, V] | None, optional): Initial items. Defaults to None. + + Raises: + ValueError: Raised when capacity is not positive. + """ + if capacity < 0: + raise ValueError("Capacity must be a positive integer") + self._capacity = capacity + self._loose_mode = loose_mode + self._cache: OrderedDict[K, weakref.ref[V]] = OrderedDict() + if items: + for key, value in items.items(): + self.put(key, value) + + def _cleanup_key_if_expired(self, key: K) -> bool: + """Clean a key if it's expired. + + Args: + key (K): Key in this cache. + + Returns: + bool: True if this key hasn't expired. Otherwise, return False. + """ + if key not in self._cache: + return False + + weak_ref = self._cache[key] + if weak_ref() is None: + self._cache.pop(key, None) + return False + return True + + def resize(self, new_size: int): + """Resize the cache. + + Args: + new_size (int): New cache size. + """ + self._capacity = new_size + + def set_loose(self, loose: bool): + """Set loose mode. + + Args: + loose (bool): Loose mode. + """ + self._loose_mode = loose + + @property + def loose(self) -> bool: + """Get loose mode. + + Returns: + bool: Loose mode. + """ + return self._loose_mode + + @overload + def get(self, key: K) -> V | None: ... + @overload + def get(self, key: K, default: T) -> V | T: ... + def get(self, key: K, default: T = None) -> V | T: + """Get a value from cache. + + Args: + key (K): Key in this cache. + + Returns: + V | None: Value in this cache. + """ + + if key not in self._cache: + return default + + weak_ref = self._cache[key] + value: V | None = weak_ref() + + if value is None: + self._cache.pop(key, None) + return default + + self._cache.pop(key) + self._cache[key] = weak_ref + return value + + def put(self, key: K, value: V) -> None: + """Put a value into cache. + + Args: + key (K): Key in this cache. + value (V): Value in this cache. + """ + if value is None: + raise ValueError("Cannot store None value in WeakValueLRUCache") + + weak_ref: weakref.ReferenceType[V] = weakref.ref(value) + + if key in self._cache: + self._cache.pop(key) + else: + should_expire_count = max(0, (len(self._cache) + 1) - self._capacity) + collected = 0 + for _ in range(len(self._cache)): + if collected >= should_expire_count: + break + oldest_key: K = next(iter(self._cache)) + oldest_ref = self._cache[oldest_key] + if oldest_ref() is None or not self._loose_mode: + collected += 1 + self._cache.pop(oldest_key) + elif self._loose_mode: + self._cache.move_to_end(oldest_key) + + self._cache[key] = weak_ref + + def expire(self, length: int | None = None) -> None: + """Expire cache of given length + + Args: + length (int | None, optional): Length. Defaults to None. + """ + if length is None: + length = int(len(self._cache) * (1 / 5)) + keys_to_check = list(self._cache.keys())[: min(length, len(self._cache))] + expired_keys = [key for key in keys_to_check if self._cache[key]() is None] + + for key in expired_keys: + self._cache.pop(key, None) + + def __getitem__(self, key: K) -> V: + value = self.get(key) + if value is None: + raise KeyError(key) + return value + + def __setitem__(self, key: K, value: V) -> None: + self.put(key, value) + + def __delitem__(self, key: K) -> None: + + if key not in self._cache: + raise KeyError(key) + del self._cache[key] + + def __contains__(self, key: K) -> bool: + if key not in self._cache: + return False + return self._cache[key]() is not None + + def __len__(self) -> int: + """!!!This will return the number of non-expired items in the cache.!!!""" + return len(self._cache) + + def __iter__(self) -> Iterator[K]: + for key in list(self._cache.keys()): + if self._cleanup_key_if_expired(key): + yield key + + def keys(self) -> Iterator[K]: + """Return a iterator that yield keys. + + Yields: + Iterator[Hashable]: Iterator + """ + return self.__iter__() + + def values(self) -> Generator[V, Any, None]: + for _, value in self.items(): # noqa: PERF102 + yield value + + def items(self) -> Generator[tuple[K, V], Any, None]: + + for key, weak_ref in list(self._cache.items()): + if self._cleanup_key_if_expired(key): + value = weak_ref() + assert value is not None + yield key, value + + def clear(self) -> None: + """Remove all items from cache.""" + self._cache.clear() + + @property + def capacity(self) -> int: + """Get cache capacity.""" + return self._capacity + + def get_capacity(self) -> int: + """Get cache capacity.""" + return self._capacity + + def size(self) -> int: + """Return the valid size of cache.""" + t = 0 + for i in self._cache.values(): + if i() is not None: + t += 1 + return t + + def is_full(self) -> bool: + """Check if cache is full.""" + return len(self._cache) >= self._capacity + + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: T) -> V | T: ... + + def pop(self, key: K, default: T = __marker) -> V | T: + """Remove and return item associated with key if key is in cache, else default. + + Args: + key (K): Key in this cache. + default (T, optional): Default value. Defaults to __marker. + + Returns: + V | T: Value in this cache. + """ + if key not in self._cache: + if default is self.__marker: + raise KeyError(key) + return default + + weak_ref = self._cache.pop(key) + value = weak_ref() + if value is None: + if default is self.__marker: + raise KeyError(key) + return default + return value + + def __repr__(self) -> str: + + items = [] + for k, weak_ref in self._cache.items(): + v = weak_ref() + if v is not None: + items.append(f"{k!r}: {v!r}") + return f"{self.__class__.__name__}(capacity={self._capacity}, items={{{', '.join(items)}}})" diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2db8848..8e4f65e 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,5 +1,5 @@ # type: ignore -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from openai import AsyncStream @@ -13,7 +13,7 @@ Function as ToolCallFunction, ) -from amrita_core.builtins.adapter import OpenAIAdapter +from amrita_core.builtins.adapter import AnthropicAdapter, OpenAIAdapter from amrita_core.config import AmritaConfig from amrita_core.tools.models import ToolFunctionSchema from amrita_core.types import ModelPreset, ToolCall, UniResponse @@ -99,92 +99,6 @@ async def test_call_api_non_streaming(self, adapter, mock_messages): assert results[1].usage is not None assert results[1].usage.prompt_tokens == 10 - @pytest.mark.asyncio - async def test_call_api_streaming(self, adapter, mock_messages): - """Test call_api with streaming response""" - # Set stream to True - adapter.preset.config.stream = True - - # Create mock chunks - chunk1 = ChatCompletionChunk( - id="chatcmpl-123", - choices=[ - {"index": 0, "delta": {"content": "Hello"}, "finish_reason": None} - ], - created=1234567890, - model="gpt-3.5-turbo", - object="chat.completion.chunk", - ) - chunk2 = ChatCompletionChunk( - id="chatcmpl-123", - choices=[ - {"index": 0, "delta": {"content": " there!"}, "finish_reason": "stop"} - ], - created=1234567890, - model="gpt-3.5-turbo", - object="chat.completion.chunk", - usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, - ) - - with patch("amrita_core.builtins.adapter.openai.AsyncOpenAI") as mock_openai: - mock_client = AsyncMock() - mock_stream = MockAsyncStream([chunk1, chunk2]) - mock_client.chat.completions.create.return_value = mock_stream - mock_openai.return_value = mock_client - - # Call the method - results = [] - async for result in adapter.call_api(mock_messages): - results.append(result) - - # Verify results - assert len(results) == 3 # "Hello" + " there!" + UniResponse - assert results[0] == "Hello" - assert results[1] == " there!" - assert isinstance(results[2], UniResponse) - assert results[2].content == "Hello there!" - assert results[2].usage is not None - - @pytest.mark.asyncio - async def test_call_api_streaming_with_empty_content(self, adapter, mock_messages): - """Test call_api with streaming response that has empty content chunks""" - adapter.preset.config.stream = True - - # Create mock chunks with some empty content - chunk1 = ChatCompletionChunk( - id="chatcmpl-123", - choices=[{"index": 0, "delta": {"content": ""}, "finish_reason": None}], - created=1234567890, - model="gpt-3.5-turbo", - object="chat.completion.chunk", - ) - chunk2 = ChatCompletionChunk( - id="chatcmpl-123", - choices=[ - {"index": 0, "delta": {"content": "Hello"}, "finish_reason": "stop"} - ], - created=1234567890, - model="gpt-3.5-turbo", - object="chat.completion.chunk", - ) - - with patch("amrita_core.builtins.adapter.openai.AsyncOpenAI") as mock_openai: - mock_client = AsyncMock() - mock_stream = MockAsyncStream([chunk1, chunk2]) - mock_client.chat.completions.create.return_value = mock_stream - mock_openai.return_value = mock_client - - results = [] - async for result in adapter.call_api(mock_messages): - results.append(result) - - # Empty string is also yielded since it's not None - assert len(results) == 3 - assert results[0] == "" - assert results[1] == "Hello" - assert isinstance(results[2], UniResponse) - assert results[2].content == "Hello" - @pytest.mark.asyncio async def test_call_api_non_streaming_empty_content(self, adapter, mock_messages): """Test call_api with non-streaming response that has empty content""" @@ -412,3 +326,608 @@ async def test_call_tools_with_string_tool_choice(self, adapter, mock_messages): assert isinstance(result, UniResponse) assert result.content is None assert result.tool_calls is None + + +class TestAnthropicAdapter: + """Test AnthropicAdapter functionality""" + + @pytest.fixture + def anthropic_adapter(self): + """Create AnthropicAdapter instance with mock config and preset""" + config = AmritaConfig() + preset = ModelPreset( + model="claude-3-opus-20240229", + base_url="https://api.anthropic.com", + api_key="test-key", + ) + return AnthropicAdapter(config=config, preset=preset) + + @pytest.fixture + def simple_messages(self): + """Create simple messages for testing""" + return [ + { + "role": "user", + "content": "Hello!", + }, # Remove system message for tool tests + ] + + @pytest.fixture + def messages_with_tool_calls(self): + """Create messages with tool calls for testing""" + return [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "function": { + "name": "get_weather", + "arguments": '{"location": "New York"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "content": "The weather is sunny, 75°F", + }, + ] + + def test_get_adapter_protocol(self): + """Test get_adapter_protocol method""" + protocol = AnthropicAdapter.get_adapter_protocol() + assert protocol == ("anthropic", "claude") + + def test_convert_content_to_blocks_text(self): + """Test _convert_content_to_blocks with plain text""" + content = "Hello world" + blocks = AnthropicAdapter._convert_content_to_blocks(content) + expected = [{"type": "text", "text": "Hello world"}] + assert blocks == expected + + def test_convert_content_to_blocks_none(self): + """Test _convert_content_to_blocks with None""" + blocks = AnthropicAdapter._convert_content_to_blocks(None) + assert blocks == [] + + def test_convert_content_to_blocks_list_text(self): + """Test _convert_content_to_blocks with list of text content""" + content = [{"type": "text", "text": "Hello"}, {"type": "text", "text": "World"}] + blocks = AnthropicAdapter._convert_content_to_blocks(content) + expected = [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "World"}, + ] + assert blocks == expected + + def test_convert_content_to_blocks_list_image(self): + """Test _convert_content_to_blocks with image content""" + content = [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"}, + }, + ] + blocks = AnthropicAdapter._convert_content_to_blocks(content) + expected = [ + {"type": "text", "text": "Hello"}, + { + "type": "image", + "source": {"type": "url", "url": "https://example.com/image.jpg"}, + }, + ] + assert blocks == expected + + def test_convert_content_to_blocks_empty_list(self): + """Test _convert_content_to_blocks with empty list returns default text block""" + content = [] + blocks = AnthropicAdapter._convert_content_to_blocks(content) + expected = [{"type": "text", "text": ""}] + assert blocks == expected + + def test_convert_messages_system_only(self): + """Test _convert_messages with only system message""" + messages = [{"role": "system", "content": "You are an AI assistant."}] + converted = AnthropicAdapter._convert_messages(messages) + expected = [{"role": "system", "content": "You are an AI assistant."}] + assert converted == expected + + def test_convert_messages_user_text(self): + """Test _convert_messages with user text message""" + messages = [{"role": "user", "content": "Hello!"}] + converted = AnthropicAdapter._convert_messages(messages) + expected = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}] + assert converted == expected + + def test_convert_messages_user_list_content(self): + """Test _convert_messages with user list content""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"}, + }, + ], + } + ] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + { + "type": "image", + "source": { + "type": "url", + "url": "https://example.com/image.jpg", + }, + }, + ], + } + ] + assert converted == expected + + def test_convert_messages_assistant_no_tools(self): + """Test _convert_messages with assistant message without tools""" + messages = [{"role": "assistant", "content": "Hello there!"}] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + {"role": "assistant", "content": [{"type": "text", "text": "Hello there!"}]} + ] + assert converted == expected + + def test_convert_messages_assistant_with_tools(self): + """Test _convert_messages with assistant message with tool calls""" + messages = [ + { + "role": "assistant", + "tool_calls": [ + { + "id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "function": { + "name": "get_weather", + "arguments": '{"location": "New York"}', + }, + } + ], + } + ] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "name": "get_weather", + "input": {"location": "New York"}, + } + ], + } + ] + assert converted == expected + + def test_convert_messages_tool_messages(self): + """Test _convert_messages with tool messages (should be merged into user message)""" + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "function": { + "name": "get_weather", + "arguments": '{"location": "New York"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "content": "The weather is sunny, 75°F", + }, + ] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + { + "role": "user", + "content": [{"type": "text", "text": "What's the weather?"}], + }, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "name": "get_weather", + "input": {"location": "New York"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_bdrk_01K2K2K2K2K2K2K2K2K2K2", + "content": "The weather is sunny, 75°F", + } + ], + }, + ] + assert converted == expected + + def test_convert_messages_mixed_roles(self): + """Test _convert_messages with mixed roles including system""" + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + {"role": "assistant", "content": "Hi there!"}, + ] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}, + ] + assert converted == expected + + def test_convert_messages_system_with_list_content(self): + """Test _convert_messages with system message containing list content""" + messages = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant."}, + {"type": "text", "text": "Always be polite."}, + ], + } + ] + converted = AnthropicAdapter._convert_messages(messages) + expected = [ + { + "role": "system", + "content": "You are a helpful assistant.Always be polite.", + } + ] + assert converted == expected + + def test_convert_tools_empty(self): + """Test _convert_tools with empty tools list""" + tools = [] + converted = AnthropicAdapter._convert_tools(tools) + assert converted == [] + + def test_convert_tools_single_tool(self): + """Test _convert_tools with single tool""" + from amrita_core.tools.models import ( + FunctionDefinitionSchema, + FunctionParametersSchema, + ) + + tool_schema = ToolFunctionSchema( + function=FunctionDefinitionSchema( + name="get_weather", + description="Get the current weather", + parameters=FunctionParametersSchema( + type="object", + properties={"location": {"type": "string"}}, + required=["location"], + ), + ), + strict=False, + ) + converted = AnthropicAdapter._convert_tools([tool_schema]) + expected = [ + { + "name": "get_weather", + "description": "Get the current weather", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "No description"} + }, + "required": ["location"], + }, + "strict": False, + } + ] + assert converted == expected + + def test_convert_tools_strict_tool(self): + """Test _convert_tools with strict tool""" + from amrita_core.tools.models import ( + FunctionDefinitionSchema, + FunctionParametersSchema, + ) + + tool_schema = ToolFunctionSchema( + function=FunctionDefinitionSchema( + name="calculate", + description="Perform calculation", + parameters=FunctionParametersSchema( + type="object", + properties={"expression": {"type": "string"}}, + required=["expression"], + ), + ), + strict=True, + ) + converted = AnthropicAdapter._convert_tools([tool_schema]) + expected = [ + { + "name": "calculate", + "description": "Perform calculation", + "input_schema": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "No description", + } + }, + "required": ["expression"], + }, + "strict": True, + } + ] + assert converted == expected + + def test_convert_tool_choice_none(self): + """Test _convert_tool_choice with None""" + choice = AnthropicAdapter._convert_tool_choice(None) + expected = {"type": "auto"} + assert choice == expected + + def test_convert_tool_choice_auto(self): + """Test _convert_tool_choice with 'auto'""" + choice = AnthropicAdapter._convert_tool_choice("auto") + expected = {"type": "auto"} + assert choice == expected + + def test_convert_tool_choice_none_string(self): + """Test _convert_tool_choice with 'none'""" + choice = AnthropicAdapter._convert_tool_choice("none") + expected = {"type": "none"} + assert choice == expected + + def test_convert_tool_choice_required(self): + """Test _convert_tool_choice with 'required'""" + choice = AnthropicAdapter._convert_tool_choice("required") + expected = {"type": "any"} + assert choice == expected + + def test_convert_tool_choice_specific_function(self): + """Test _convert_tool_choice with specific ToolFunctionSchema""" + tool_schema = ToolFunctionSchema( + function={ + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object", "properties": {}}, + } + ) + choice = AnthropicAdapter._convert_tool_choice(tool_schema) + expected = {"type": "tool", "name": "get_weather"} + assert choice == expected + + def test_convert_tool_choice_invalid(self): + """Test _convert_tool_choice with invalid choice raises ValueError""" + with pytest.raises(ValueError, match="Invalid choice: invalid_choice"): + AnthropicAdapter._convert_tool_choice("invalid_choice") + + @pytest.mark.asyncio + async def test_call_api_non_streaming(self, anthropic_adapter, simple_messages): + """Test call_api with non-streaming Anthropic response""" + # Mock Anthropic client response + from anthropic.types import TextBlock, Usage + + mock_message = MagicMock() + mock_message.content = [TextBlock(text="Hello there!", type="text")] + mock_message.usage = Usage(input_tokens=10, output_tokens=5, total_tokens=15) + + with patch( + "amrita_core.builtins.adapter.anthropic.AsyncAnthropic" + ) as mock_anthropic: + mock_client = AsyncMock() + mock_client.messages.create.return_value = mock_message + mock_anthropic.return_value = mock_client + + # Disable streaming for this test + anthropic_adapter.preset.config.stream = False + + results = [] + async for result in anthropic_adapter.call_api(simple_messages): + results.append(result) + + assert len(results) == 2 + assert results[0] == "Hello there!" + assert isinstance(results[1], UniResponse) + assert results[1].content == "Hello there!" + assert results[1].usage is not None + assert results[1].usage.prompt_tokens == 10 + assert results[1].usage.completion_tokens == 5 + + @pytest.mark.asyncio + async def test_call_api_streaming(self, anthropic_adapter, simple_messages): + """Test call_api with streaming Anthropic response""" + from anthropic.types import TextBlock, Usage + + # Create a simple async generator to simulate text_stream + async def mock_text_stream(): + yield "Hello" + yield " there!" + + # Create a mock final message + mock_final_message = MagicMock() + mock_final_message.content = [TextBlock(text="Hello there!", type="text")] + mock_final_message.usage = Usage( + input_tokens=10, output_tokens=5, total_tokens=15 + ) + + # Create a simple mock context manager + class SimpleMockContext: + def __init__(self, text_stream_gen, final_msg): + self._text_stream = text_stream_gen + self._final_msg = final_msg + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + @property + def text_stream(self): + return self._text_stream() + + async def get_final_message(self): + return self._final_msg + + # Directly mock the stream method to return our context manager (not async!) + def mock_stream_method(**kwargs): + return SimpleMockContext(mock_text_stream, mock_final_message) + + with patch( + "amrita_core.builtins.adapter.anthropic.AsyncAnthropic" + ) as mock_anthropic: + mock_client = AsyncMock() + mock_client.messages.stream = mock_stream_method + mock_anthropic.return_value = mock_client + + # Enable streaming for this test + anthropic_adapter.preset.config.stream = True + + results = [] + async for result in anthropic_adapter.call_api(simple_messages): + results.append(result) + + assert len(results) == 3 + assert results[0] == "Hello" + assert results[1] == " there!" + assert isinstance(results[2], UniResponse) + assert results[2].content == "Hello there!" + assert results[2].usage is not None + assert results[2].usage.prompt_tokens == 10 + assert results[2].usage.completion_tokens == 5 + + @pytest.mark.asyncio + async def test_call_api_streaming_empty_content( + self, anthropic_adapter, simple_messages + ): + """Test call_api with streaming Anthropic response that has empty content""" + from anthropic.types import TextBlock, Usage + + async def mock_text_stream(): + yield "" + + mock_final_message = MagicMock() + mock_final_message.content = [TextBlock(text="", type="text")] + mock_final_message.usage = Usage( + input_tokens=10, output_tokens=0, total_tokens=10 + ) + + class SimpleMockContext: + def __init__(self, text_stream_gen, final_msg): + self._text_stream = text_stream_gen + self._final_msg = final_msg + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + @property + def text_stream(self): + return self._text_stream() + + async def get_final_message(self): + return self._final_msg + + def mock_stream_method(**kwargs): + return SimpleMockContext(mock_text_stream, mock_final_message) + + with patch( + "amrita_core.builtins.adapter.anthropic.AsyncAnthropic" + ) as mock_anthropic: + mock_client = AsyncMock() + mock_client.messages.stream = mock_stream_method + mock_anthropic.return_value = mock_client + + anthropic_adapter.preset.config.stream = True + + results = [] + async for result in anthropic_adapter.call_api(simple_messages): + results.append(result) + + assert len(results) == 2 + assert results[0] == "" + assert isinstance(results[1], UniResponse) + assert results[1].content == "" + assert results[1].usage is not None + + @pytest.mark.asyncio + async def test_call_api_non_streaming_multiple_blocks( + self, anthropic_adapter, simple_messages + ): + """Test call_api with non-streaming response containing multiple content blocks""" + from anthropic.types import TextBlock, Usage + + # Mock Anthropic client response with multiple text blocks + mock_message = MagicMock() + mock_message.content = [ + TextBlock(text="First part", type="text"), + TextBlock(text=" Second part", type="text"), + ] + mock_message.usage = Usage(input_tokens=15, output_tokens=10, total_tokens=25) + + with patch( + "amrita_core.builtins.adapter.anthropic.AsyncAnthropic" + ) as mock_anthropic: + mock_client = AsyncMock() + mock_client.messages.create.return_value = mock_message + mock_anthropic.return_value = mock_client + + anthropic_adapter.preset.config.stream = False + + results = [] + async for result in anthropic_adapter.call_api(simple_messages): + results.append(result) + + assert len(results) == 2 + assert results[0] == "First part Second part" + assert isinstance(results[1], UniResponse) + assert results[1].content == "First part Second part" + assert results[1].usage.prompt_tokens == 15 + assert results[1].usage.completion_tokens == 10 + + @pytest.mark.asyncio + async def test_call_tools_basic(self, anthropic_adapter, messages_with_tool_calls): + """Test call_tools with basic tool call scenario""" + # This would require mocking the Anthropic client's tool calling behavior + # For now, we'll focus on the message conversion parts which are the core logic + + # Test the message conversion that happens before calling the API + converted_messages = AnthropicAdapter._convert_messages( + messages_with_tool_calls + ) + assert len(converted_messages) == 3 + assert converted_messages[0]["role"] == "user" + assert converted_messages[1]["role"] == "assistant" + assert ( + converted_messages[2]["role"] == "user" + ) # tool results merged into user message + + # The actual API call would be tested with proper mocking in a real implementation + # For coverage purposes, we've tested all the helper methods thoroughly diff --git a/tests/test_weakcache.py b/tests/test_weakcache.py new file mode 100644 index 0000000..16d1479 --- /dev/null +++ b/tests/test_weakcache.py @@ -0,0 +1,510 @@ +import gc +from typing import Any + +import pytest + +from amrita_core.weakcache import WeakValueLRUCache + + +class TestObject: + """Test object for weak reference testing""" + + def __init__(self, value: Any): + self.value = value + + def __repr__(self): + return f"TestObject({self.value!r})" + + def __eq__(self, other): + if isinstance(other, TestObject): + return self.value == other.value + return False + + def __hash__(self): + return hash(self.value) + + +class TestWeakValueLRUCache: + """Test suite for WeakValueLRUCache""" + + def test_init_with_negative_capacity(self): + """Test that negative capacity raises ValueError""" + with pytest.raises(ValueError, match="Capacity must be a positive integer"): + WeakValueLRUCache(capacity=-1) + + def test_init_with_zero_capacity(self): + """Test initialization with zero capacity""" + cache = WeakValueLRUCache(capacity=0) + assert cache.capacity == 0 + assert len(cache) == 0 + + def test_init_with_positive_capacity(self): + """Test normal initialization""" + cache = WeakValueLRUCache(capacity=5) + assert cache.capacity == 5 + assert len(cache) == 0 + + def test_init_with_items(self): + """Test initialization with initial items""" + obj1 = TestObject("value1") + obj2 = TestObject("value2") + items = {"key1": obj1, "key2": obj2} + cache = WeakValueLRUCache(capacity=5, items=items) + + assert cache.get("key1") is obj1 + assert cache.get("key2") is obj2 + assert len(cache) == 2 + + def test_put_none_value_raises_error(self): + """Test that putting None value raises ValueError""" + cache = WeakValueLRUCache(capacity=5) + with pytest.raises( + ValueError, match="Cannot store None value in WeakValueLRUCache" + ): + cache.put("key", None) + + def test_put_and_get_basic(self): + """Test basic put and get operations""" + cache = WeakValueLRUCache(capacity=5) + obj = TestObject("test") + + cache.put("key", obj) + retrieved = cache.get("key") + + assert retrieved is obj + assert len(cache) == 1 + + def test_get_nonexistent_key(self): + """Test getting non-existent key returns None or default""" + cache = WeakValueLRUCache(capacity=5) + + # Default behavior + assert cache.get("nonexistent") is None + + # With default value + assert cache.get("nonexistent", "default") == "default" + + def test_get_expired_key(self): + """Test getting expired (garbage collected) key""" + cache = WeakValueLRUCache(capacity=5) + + # Create an object and add it to cache + obj = TestObject("test") + cache.put("key", obj) + assert cache.get("key") is obj + + # Delete the object and force garbage collection + del obj + gc.collect() + + # The key should now return None/default + assert cache.get("key") is None + assert cache.get("key", "default") == "default" + assert len(cache) == 0 # Should be cleaned up + + def test_lru_eviction_normal_mode(self): + """Test LRU eviction in normal mode""" + cache = WeakValueLRUCache(capacity=2) + + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + # Adding third item should evict the oldest (key1) + cache.put("key3", obj3) + assert len(cache) == 2 + assert cache.get("key1") is None # Should be evicted + assert cache.get("key2") is obj2 + assert cache.get("key3") is obj3 + + def test_lru_eviction_with_existing_key(self): + """Test that updating existing key doesn't cause eviction""" + cache = WeakValueLRUCache(capacity=2) + + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + # Updating existing key should move it to end (most recent) + cache.put("key1", obj1) + assert len(cache) == 2 + + # Adding new key should evict key2 (now oldest) + cache.put("key3", obj3) + assert len(cache) == 2 + assert cache.get("key1") is obj1 + assert cache.get("key2") is None # Should be evicted + assert cache.get("key3") is obj3 + + def test_loose_mode_enabled(self): + """Test loose mode behavior""" + cache = WeakValueLRUCache(capacity=2, loose_mode=True) + + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + # In loose mode, adding third item should not evict if all refs are alive + cache.put("key3", obj3) + assert len(cache) == 3 # Should exceed capacity in loose mode + assert cache.get("key1") is obj1 + assert cache.get("key2") is obj2 + assert cache.get("key3") is obj3 + + def test_loose_mode_with_expired_refs(self): + """Test loose mode with expired references""" + cache = WeakValueLRUCache(capacity=2, loose_mode=True) + + obj1 = TestObject("1") + obj2 = TestObject("2") + + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + # Let obj1 go out of scope and be garbage collected + del obj1 + gc.collect() + + # Add new item - should clean up expired ref and add new one + obj3 = TestObject("3") + cache.put("key3", obj3) + assert len(cache) == 2 + assert cache.get("key1") is None # Expired + assert cache.get("key2") is obj2 + assert cache.get("key3") is obj3 + + def test_resize_capacity(self): + """Test resizing cache capacity""" + cache = WeakValueLRUCache(capacity=5) + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + assert cache.capacity == 5 + assert len(cache) == 3 + + # Resize to smaller capacity + cache.resize(2) + assert cache.capacity == 2 + + # Cache should still contain all items until next put operation + assert len(cache) == 3 + + # Adding new item should trigger eviction based on new capacity + obj4 = TestObject("4") + cache.put("key4", obj4) + assert len(cache) == 2 # Should evict based on new capacity of 2 + + def test_set_loose_mode(self): + """Test setting loose mode dynamically""" + cache = WeakValueLRUCache(capacity=2, loose_mode=False) + assert cache.loose is False + + cache.set_loose(True) + assert cache.loose is True + + cache.set_loose(False) + assert cache.loose is False + + def test_expire_method(self): + """Test expire method""" + cache = WeakValueLRUCache(capacity=5) + + # Add some objects + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + obj4 = TestObject("4") + obj5 = TestObject("5") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + cache.put("key4", obj4) + cache.put("key5", obj5) + + # Let some objects expire + del obj1 + del obj3 + del obj5 + gc.collect() + + # Before expire, cache should have 5 entries (including expired) + assert len(cache) == 5 + + # Expire with specific length - should check first 'length' keys + cache.expire(3) + # After expire, the expired keys in the first 3 positions should be removed + # key1 and key3 were expired, so they should be removed from first 3 + assert "key1" not in cache._cache # Should be removed + assert "key3" not in cache._cache # Should be removed + assert "key2" in cache._cache # Should remain (not expired) + + # Test expire with default length (1/5 of cache) + # Reset cache + cache.clear() + for i in range(10): + obj = TestObject(f"value{i}") + cache.put(f"key{i}", obj) + + # Expire some objects + for i in [0, 1, 2]: + # We can't easily delete specific objects here, so just test the method call + pass + + cache.expire() # Should not crash + + def test_magic_methods_getitem_setitem_delitem(self): + """Test __getitem__, __setitem__, __delitem__""" + cache = WeakValueLRUCache(capacity=5) + obj = TestObject("test") + + # Test __setitem__ + cache["key"] = obj + assert cache["key"] is obj + + # Test __getitem__ with missing key + with pytest.raises(KeyError): + _ = cache["nonexistent"] + + # Test __delitem__ + del cache["key"] + with pytest.raises(KeyError): + del cache["key"] # Already deleted + + # Test __delitem__ with non-existent key + with pytest.raises(KeyError): + del cache["nonexistent"] + + def test_contains_method(self): + """Test __contains__ method""" + cache = WeakValueLRUCache(capacity=5) + obj = TestObject("test") + + cache.put("key", obj) + assert "key" in cache + + # Test with expired key + del obj + gc.collect() + assert "key" not in cache + + # Test with non-existent key + assert "nonexistent" not in cache + + def test_len_method(self): + """Test __len__ method""" + cache = WeakValueLRUCache(capacity=5) + assert len(cache) == 0 + + obj1 = TestObject("1") + obj2 = TestObject("2") + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + # Test with expired object + del obj1 + gc.collect() + # Note: __len__ returns total entries including expired ones + # until they are accessed/cleaned up + assert len(cache) == 2 # Still 2 until cleanup happens + + # Accessing should trigger cleanup + _ = cache.get("key1") + assert len(cache) == 1 + + def test_iter_keys_methods(self): + """Test __iter__ and keys methods""" + cache = WeakValueLRUCache(capacity=5) + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + + keys_list = list(cache) + assert set(keys_list) == {"key1", "key2", "key3"} + + keys_from_keys_method = list(cache.keys()) + assert set(keys_from_keys_method) == {"key1", "key2", "key3"} + + # Test with expired object + del obj2 + gc.collect() + + keys_after_gc = list(cache) + assert set(keys_after_gc) == {"key1", "key3"} + + def test_clear_method(self): + """Test clear method""" + cache = WeakValueLRUCache(capacity=5) + obj1 = TestObject("1") + obj2 = TestObject("2") + + cache.put("key1", obj1) + cache.put("key2", obj2) + assert len(cache) == 2 + + cache.clear() + assert len(cache) == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + + def test_capacity_properties(self): + """Test capacity property and get_capacity method""" + cache = WeakValueLRUCache(capacity=10) + assert cache.capacity == 10 + assert cache.get_capacity() == 10 + + def test_size_method(self): + """Test size method (counts only non-expired items)""" + cache = WeakValueLRUCache(capacity=5) + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + assert cache.size() == 3 + + # Expire one object + del obj2 + gc.collect() + assert cache.size() == 2 + + def test_is_full_method(self): + """Test is_full method""" + cache = WeakValueLRUCache(capacity=2) + assert not cache.is_full() + + obj1 = TestObject("1") + cache.put("key1", obj1) + assert not cache.is_full() + + obj2 = TestObject("2") + cache.put("key2", obj2) + assert cache.is_full() + + # Even with expired objects, is_full checks total entries + del obj1 + gc.collect() + assert cache.is_full() # Still has 2 entries (one expired) + + def test_pop_method(self): + """Test pop method""" + cache = WeakValueLRUCache(capacity=5) + obj = TestObject("test") + + cache.put("key", obj) + + # Pop existing key + popped = cache.pop("key") + assert popped is obj + assert len(cache) == 0 + + # Pop non-existent key without default + with pytest.raises(KeyError): + cache.pop("nonexistent") + + # Pop non-existent key with default + assert cache.pop("nonexistent", "default") == "default" + + # Test pop with expired key + obj2 = TestObject("test2") + cache.put("key2", obj2) + del obj2 + gc.collect() + + # Pop expired key without default + with pytest.raises(KeyError): + cache.pop("key2") + + # Pop expired key with default + assert cache.pop("key2", "default") == "default" + + def test_repr_method(self): + """Test __repr__ method""" + cache = WeakValueLRUCache(capacity=5) + obj1 = TestObject("1") + obj2 = TestObject("2") + + cache.put("key1", obj1) + cache.put("key2", obj2) + + repr_str = repr(cache) + assert "capacity=5" in repr_str + assert "'key1': TestObject('1')" in repr_str + assert "'key2': TestObject('2')" in repr_str + + # Test with expired object + del obj1 + gc.collect() + repr_str_after_gc = repr(cache) + assert "'key1'" not in repr_str_after_gc + assert "'key2': TestObject('2')" in repr_str_after_gc + + def test_get_moves_to_end(self): + """Test that get operation moves item to end (most recently used)""" + cache = WeakValueLRUCache(capacity=3) + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + + # Get key1 to move it to end + _ = cache.get("key1") + + # Add new item - should evict key2 (now oldest) + obj4 = TestObject("4") + cache.put("key4", obj4) + + assert cache.get("key1") is obj1 # Should still be there + assert cache.get("key2") is None # Should be evicted + assert cache.get("key3") is obj3 + assert cache.get("key4") is obj4 + + def test_put_existing_key_moves_to_end(self): + """Test that putting existing key moves it to end""" + cache = WeakValueLRUCache(capacity=3) + obj1 = TestObject("1") + obj2 = TestObject("2") + obj3 = TestObject("3") + + cache.put("key1", obj1) + cache.put("key2", obj2) + cache.put("key3", obj3) + + # Put existing key1 again + cache.put("key1", obj1) + + # Add new item - should evict key2 (now oldest) + obj4 = TestObject("4") + cache.put("key4", obj4) + + assert cache.get("key1") is obj1 + assert cache.get("key2") is None # Should be evicted + assert cache.get("key3") is obj3 + assert cache.get("key4") is obj4 diff --git a/uv.lock b/uv.lock index 07bde28..27cab1e 100644 --- a/uv.lock +++ b/uv.lock @@ -181,7 +181,7 @@ wheels = [ [[package]] name = "amrita-core" -version = "0.8.3" +version = "0.8.5" source = { editable = "." } dependencies = [ { name = "aiofiles" },