class AsyncAgentRuntimeClient:
    """Async client for invoking AgentRuntime services.

    Usage::

        # Async context manager (recommended)
        async with AsyncAgentRuntimeClient(agent_name="my-agent", ...) as client:
            result = await client.invoke({"input": "hello"})

        # Manual lifecycle management
        client = AsyncAgentRuntimeClient(agent_name="my-agent", ...)
        await client.start()
        try:
            result = await client.invoke({"input": "hello"})
        finally:
            await client.close()
    """

    def __init__(
        self,
        agent_name: str,
        namespace: str = "default",
        router_url: Optional[str] = None,
        verbose: bool = False,
        session_id: Optional[str] = None,
        timeout: int = 120,
        connect_timeout: float = 5.0,
    ):
        """Store configuration and build the data-plane client.

        The session is not bootstrapped here; call ``start()`` or use the
        client as an async context manager.

        Args:
            agent_name: Name of the agent runtime to invoke.
            namespace: Namespace the agent runtime lives in.
            router_url: Router base URL; falls back to the ``ROUTER_URL``
                environment variable when omitted.
            verbose: Enable DEBUG logging on this client and its data-plane
                client.
            session_id: Existing session ID to reuse; when omitted, a new one
                is bootstrapped by ``start()``.
            timeout: Default request timeout in seconds.
            connect_timeout: Connection timeout in seconds.

        Raises:
            ValueError: If no router URL is given and ``ROUTER_URL`` is unset.
        """
        self.agent_name = agent_name
        self.namespace = namespace
        self.timeout = timeout
        self.connect_timeout = connect_timeout

        level = logging.DEBUG if verbose else logging.INFO
        self.logger = get_logger(__name__, level=level)

        router_url = router_url or os.getenv("ROUTER_URL")
        if not router_url:
            raise ValueError(
                "Router URL for Data Plane communication must be provided via "
                "'router_url' argument or 'ROUTER_URL' environment variable."
            )
        self.router_url = router_url

        self.session_id: Optional[str] = session_id
        self.dp_client = AsyncAgentRuntimeDataPlaneClient(
            router_url=self.router_url,
            namespace=self.namespace,
            agent_name=self.agent_name,
            timeout=self.timeout,
            connect_timeout=self.connect_timeout,
        )
        if verbose:
            self.dp_client.logger.setLevel(logging.DEBUG)

    async def start(self) -> None:
        """Bootstrap the session ID if not already set."""
        if not self.session_id:
            self.logger.info("Bootstrapping AgentRuntime session...")
            self.session_id = await self.dp_client.bootstrap_session_id()
            self.logger.info(f"AgentRuntime session created: {self.session_id}")
        else:
            self.logger.info(f"Reusing AgentRuntime session: {self.session_id}")

    async def __aenter__(self) -> "AsyncAgentRuntimeClient":
        """Start the client; release the HTTP session if start-up fails."""
        try:
            await self.start()
        except BaseException:
            # __aexit__ is never called when __aenter__ raises, so the HTTP
            # session created in __init__ must be released here.
            await self.close()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the underlying HTTP session when leaving the context."""
        await self.close()

    async def invoke(
        self, payload: Dict[str, Any], timeout: Optional[float] = None
    ) -> Any:
        """Invoke the agent runtime with a payload.

        Args:
            payload: The request payload to send.
            timeout: Optional per-request timeout in seconds.

        Returns:
            The parsed JSON response, or the raw text if the response is not JSON.

        Raises:
            ValueError: If no session ID is available (``start()`` not called).
        """
        if not self.session_id:
            raise ValueError("AgentRuntime session_id is not initialized; call start() first.")

        resp = await self.dp_client.invoke(
            session_id=self.session_id,
            payload=payload,
            timeout=timeout,
        )
        resp.raise_for_status()
        try:
            return resp.json()
        except ValueError:
            # json.JSONDecodeError is a subclass of ValueError; a body that is
            # not valid JSON lands here and is returned as raw text instead.
            return resp.text

    async def close(self) -> None:
        """Close the underlying HTTP session."""
        if self.dp_client:
            await self.dp_client.close()
class AsyncCodeInterpreterClient:
    """
    Async AgentCube Code Interpreter Client.

    Manages the lifecycle of a Code Interpreter session and exposes async
    methods to execute code and manage files within it.

    Because session creation involves network I/O, this client **must** be
    initialised with ``await AsyncCodeInterpreterClient.create(...)`` or used
    as an async context manager::

        # Async context manager (recommended)
        async with AsyncCodeInterpreterClient(...) as client:
            await client.run_code("python", "print('hello')")

        # Manual lifecycle management
        client = await AsyncCodeInterpreterClient.create(...)
        try:
            await client.run_code("python", "print('hello')")
        finally:
            await client.stop()

    ``create()`` returns a coroutine, not an async context manager, so
    ``async with AsyncCodeInterpreterClient.create(...)`` does not work.
    """

    def __init__(
        self,
        name: str = "my-interpreter",
        namespace: str = "default",
        ttl: int = 3600,
        workload_manager_url: Optional[str] = None,
        router_url: Optional[str] = None,
        auth_token: Optional[str] = None,
        verbose: bool = False,
        session_id: Optional[str] = None,
    ):
        """Store configuration; does *not* create a session.

        Use ``await AsyncCodeInterpreterClient.create(...)`` or
        ``async with AsyncCodeInterpreterClient(...)`` to obtain a usable client.

        Args:
            name: Name of the Code Interpreter resource.
            namespace: Namespace of the Code Interpreter resource.
            ttl: Session time to live in seconds.
            workload_manager_url: Control-plane URL passed to the control-plane client.
            router_url: Router base URL; falls back to ``ROUTER_URL``.
            auth_token: Token passed to the control-plane client.
            verbose: Enable DEBUG logging on this client and its sub-clients.
            session_id: Existing session ID to reuse instead of creating one.

        Raises:
            ValueError: If no router URL is given and ``ROUTER_URL`` is unset.
        """
        self.name = name
        self.namespace = namespace
        self.ttl = ttl
        self.verbose = verbose

        level = logging.DEBUG if verbose else logging.INFO
        self.logger = get_logger(__name__, level=level)

        self.cp_client = AsyncControlPlaneClient(workload_manager_url, auth_token)
        if verbose:
            self.cp_client.logger.setLevel(logging.DEBUG)

        router_url = router_url or os.getenv("ROUTER_URL")
        if not router_url:
            raise ValueError(
                "Router URL for Data Plane communication must be provided via "
                "'router_url' argument or 'ROUTER_URL' environment variable."
            )
        self.router_url = router_url

        self.session_id: Optional[str] = session_id
        self.dp_client: Optional[AsyncCodeInterpreterDataPlaneClient] = None

    # ------------------------------------------------------------------
    # Async factory / context-manager helpers
    # ------------------------------------------------------------------

    @classmethod
    async def create(
        cls,
        name: str = "my-interpreter",
        namespace: str = "default",
        ttl: int = 3600,
        workload_manager_url: Optional[str] = None,
        router_url: Optional[str] = None,
        auth_token: Optional[str] = None,
        verbose: bool = False,
        session_id: Optional[str] = None,
    ) -> "AsyncCodeInterpreterClient":
        """Create and fully initialise an AsyncCodeInterpreterClient.

        This is the preferred way to create a client when you are not using
        the async context manager. Arguments are the same as ``__init__``.
        If initialisation fails, the HTTP sessions opened so far are closed
        before the error is re-raised.
        """
        instance = cls(
            name=name,
            namespace=namespace,
            ttl=ttl,
            workload_manager_url=workload_manager_url,
            router_url=router_url,
            auth_token=auth_token,
            verbose=verbose,
            session_id=session_id,
        )
        await instance._init_or_release()
        return instance

    async def _init_or_release(self) -> None:
        """Run ``_async_init`` and close the HTTP clients if it fails."""
        try:
            await self._async_init()
        except BaseException:
            # Only the local HTTP pools are released here. A caller-supplied
            # session is deliberately not deleted on an initialisation error.
            await self._close_clients()
            raise

    async def _close_clients(self) -> None:
        """Close the data-plane (if any) and control-plane HTTP sessions."""
        try:
            if self.dp_client is not None:
                dp_client, self.dp_client = self.dp_client, None
                await dp_client.close()
        finally:
            await self.cp_client.close()

    async def _async_init(self) -> None:
        """Perform async initialisation (session creation / reuse).

        Idempotent: returns immediately when the data-plane client already
        exists, so entering the context of a client built by ``create()``
        does not create a second data-plane client.
        """
        if self.dp_client is not None:
            return

        if self.session_id:
            self.logger.info(f"Reusing existing session: {self.session_id}")
            self._init_data_plane()
            return

        self.logger.info("Creating new session...")
        self.session_id = await self.cp_client.create_session(
            name=self.name,
            namespace=self.namespace,
            ttl=self.ttl,
        )
        self.logger.info(f"Session created: {self.session_id}")
        try:
            self._init_data_plane()
        except Exception:
            self.logger.warning(
                f"Failed to initialize data plane client, "
                f"deleting session {self.session_id} to prevent resource leak"
            )
            await self.cp_client.delete_session(self.session_id)
            self.session_id = None
            raise

    def _init_data_plane(self) -> None:
        """Initialise the async Data Plane client (sync — no network I/O)."""
        self.dp_client = AsyncCodeInterpreterDataPlaneClient(
            cr_name=self.name,
            router_url=self.router_url,
            namespace=self.namespace,
            session_id=self.session_id,
        )
        if self.verbose:
            self.dp_client.logger.setLevel(logging.DEBUG)

    async def __aenter__(self) -> "AsyncCodeInterpreterClient":
        """Initialise the session (if needed) and return the client."""
        await self._init_or_release()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the session when leaving the context."""
        await self.stop()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def stop(self) -> None:
        """Stop and delete the session, releasing all resources.

        Each step runs even if the previous one raised, so a failing
        data-plane close cannot leave the remote session or the
        control-plane connection pool behind.
        """
        try:
            if self.dp_client:
                dp_client, self.dp_client = self.dp_client, None
                await dp_client.close()
        finally:
            try:
                if self.session_id:
                    self.logger.info(f"Deleting session {self.session_id}...")
                    await self.cp_client.delete_session(self.session_id)
                    self.session_id = None
            finally:
                await self.cp_client.close()

    # ------------------------------------------------------------------
    # Data Plane methods
    # ------------------------------------------------------------------

    def _require_data_plane(self) -> Any:
        """Return the data-plane client or raise if it is not initialised."""
        if not self.dp_client:
            raise RuntimeError("Data Plane client is not initialized.")
        return self.dp_client

    async def execute_command(
        self, command: str, timeout: Optional[float] = None
    ) -> str:
        """Execute a shell command.

        Args:
            command: The shell command to execute.
            timeout: Maximum time in seconds to allow for command execution.

        Returns:
            The stdout output of the command.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        return await self._require_data_plane().execute_command(command, timeout)

    async def run_code(
        self, language: str, code: str, timeout: Optional[float] = None
    ) -> str:
        """Execute a code snippet in the remote environment.

        Args:
            language: The programming language (e.g. ``"python"``, ``"bash"``).
            code: The code snippet to execute.
            timeout: Optional maximum execution time in seconds.

        Returns:
            The stdout generated by the code execution.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        return await self._require_data_plane().run_code(language, code, timeout)

    async def write_file(self, content: str, remote_path: str) -> None:
        """Write content to a file in the remote environment.

        Args:
            content: The string content to write.
            remote_path: Destination path in the remote environment.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        await self._require_data_plane().write_file(content, remote_path)

    async def upload_file(self, local_path: str, remote_path: str) -> None:
        """Upload a local file to the remote environment.

        Args:
            local_path: Path to the file on the local filesystem.
            remote_path: Destination path in the remote environment.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        await self._require_data_plane().upload_file(local_path, remote_path)

    async def download_file(self, remote_path: str, local_path: str) -> None:
        """Download a file from the remote environment to the local filesystem.

        Args:
            remote_path: Path to the file in the remote environment.
            local_path: Destination path on the local filesystem.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        await self._require_data_plane().download_file(remote_path, local_path)

    async def list_files(self, path: str = ".") -> List[Any]:
        """List files and directories in a path in the remote environment.

        Args:
            path: Directory path to list (default ``"."``).

        Returns:
            A list of file/directory information dicts.

        Raises:
            RuntimeError: If the client has not been initialised.
        """
        return await self._require_data_plane().list_files(path)
class AsyncAgentRuntimeDataPlaneClient:
    """Async HTTP client for the agent-runtime invocation endpoint on the Router."""

    # Header carrying the session ID, both in bootstrap responses and in
    # invocation requests.
    SESSION_HEADER = "x-agentcube-session-id"

    def __init__(
        self,
        router_url: str,
        namespace: str,
        agent_name: str,
        timeout: int = 120,
        connect_timeout: float = 5.0,
        connector_limit: int = 100,
        connector_limit_per_host: int = 10,
    ):
        """Record the configuration and open the pooled HTTP session.

        Args:
            router_url: Base URL of the Router service.
            namespace: Namespace of the agent runtime.
            agent_name: Name of the agent runtime.
            timeout: Default request timeout in seconds.
            connect_timeout: Connection timeout in seconds.
            connector_limit: Forwarded to ``create_async_session``.
            connector_limit_per_host: Forwarded to ``create_async_session``.
        """
        self.logger = get_logger(f"{__name__}.AsyncAgentRuntimeDataPlaneClient")

        self.router_url = router_url
        self.namespace = namespace
        self.agent_name = agent_name
        self.timeout = timeout
        self.connect_timeout = connect_timeout

        # The trailing slash is part of the invocation endpoint path.
        self.base_url = urljoin(
            router_url,
            f"/v1/namespaces/{namespace}/agent-runtimes/{agent_name}/invocations/",
        )

        self._http_session = create_async_session(
            connector_limit=connector_limit,
            connector_limit_per_host=connector_limit_per_host,
        )

    def _make_timeout(self, read_timeout: Optional[float] = None) -> httpx.Timeout:
        """Build an httpx.Timeout from *read_timeout*, or the client default.

        The connect phase always uses ``self.connect_timeout``.
        """
        effective = self.timeout if read_timeout is None else read_timeout
        return httpx.Timeout(effective, connect=self.connect_timeout)

    async def bootstrap_session_id(self) -> str:
        """GET the base URL and return the session ID from the response header.

        Raises:
            ValueError: If the response does not carry the session header.
        """
        response = await self._http_session.get(
            self.base_url, timeout=self._make_timeout()
        )
        response.raise_for_status()

        session_id = response.headers.get(self.SESSION_HEADER)
        if session_id:
            return session_id
        raise ValueError(
            f"Missing required response header: {self.SESSION_HEADER}"
        )

    async def invoke(
        self,
        session_id: str,
        payload: Dict[str, Any],
        timeout: Optional[float] = None,
    ) -> httpx.Response:
        """POST *payload* as JSON to the agent runtime within *session_id*.

        The response is returned unchecked; the caller inspects its status.
        """
        request_timeout = self._make_timeout(timeout)
        request_headers = {
            self.SESSION_HEADER: session_id,
            "Content-Type": "application/json",
        }
        self.logger.debug(f"POST {self.base_url}")
        response = await self._http_session.post(
            self.base_url,
            json=payload,
            headers=request_headers,
            timeout=request_timeout,
        )
        return response

    async def close(self) -> None:
        """Close the underlying HTTP session."""
        await self._http_session.aclose()

    async def __aenter__(self) -> "AsyncAgentRuntimeDataPlaneClient":
        """Return the client itself; nothing is opened on entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP session when leaving the context."""
        await self.close()
# Extra seconds added to the HTTP read timeout on top of the PicoD command
# timeout. This gives PicoD time to finish and return its JSON response
# before httpx gives up waiting.
_TIMEOUT_BUFFER_SECONDS = 2.0


def _write_bytes(path: str, data: bytes) -> None:
    """Write bytes to a file, creating its parent directory if needed.

    Blocking; intended to be called via asyncio.to_thread.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:
        f.write(data)


def _read_bytes(path: str) -> bytes:
    """Read a whole file as bytes; intended to be called via asyncio.to_thread."""
    with open(path, "rb") as f:
        return f.read()


class AsyncCodeInterpreterDataPlaneClient:
    """Async client for AgentCube Data Plane (Router -> PicoD).

    Handles command execution and file operations via the Router.
    """

    # Same header name the agent-runtime data-plane client uses.
    SESSION_HEADER = "x-agentcube-session-id"

    def __init__(
        self,
        session_id: str,
        router_url: Optional[str] = None,
        namespace: Optional[str] = None,
        cr_name: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: int = 120,
        connect_timeout: float = 5.0,
        connector_limit: int = 100,
        connector_limit_per_host: int = 10,
    ):
        """Initialize the async Data Plane client.

        Args:
            session_id: Session ID (for x-agentcube-session-id header).
            router_url: Base URL of the Router service (optional if base_url is provided).
            namespace: Kubernetes namespace (optional if base_url is provided).
            cr_name: Code Interpreter resource name (optional if base_url is provided).
            base_url: Direct base URL for invocations (overrides router logic).
                A trailing ``/`` is appended when missing.
            timeout: Default request timeout in seconds (default: 120).
            connect_timeout: Connection timeout in seconds (default: 5).
            connector_limit: Total simultaneous connections (default: 100).
            connector_limit_per_host: Max keepalive connections per host (default: 10).

        Raises:
            ValueError: If neither ``base_url`` nor the full
                ``router_url``/``namespace``/``cr_name`` triple is provided.
        """
        self.session_id = session_id
        self.timeout = timeout
        self.connect_timeout = connect_timeout
        self.logger = get_logger(f"{__name__}.AsyncCodeInterpreterDataPlaneClient")

        if base_url:
            # urljoin() replaces the last path segment of a base that does not
            # end in "/", which would silently drop e.g. ".../invocations".
            self.base_url = base_url if base_url.endswith("/") else f"{base_url}/"
            self.cr_name = cr_name
        elif router_url and namespace and cr_name:
            self.cr_name = cr_name
            base_path = f"/v1/namespaces/{namespace}/code-interpreters/{cr_name}/invocations/"
            self.base_url = urljoin(router_url, base_path)
        else:
            raise ValueError(
                "Either 'base_url' or all of 'router_url', 'namespace', 'cr_name' must be provided."
            )

        self._http_session = create_async_session(
            connector_limit=connector_limit,
            connector_limit_per_host=connector_limit_per_host,
        )
        self._http_session.headers.update({self.SESSION_HEADER: self.session_id})

    def _make_timeout(self, read_timeout: Optional[float] = None) -> httpx.Timeout:
        """Build an httpx.Timeout from *read_timeout*, or the client default.

        The connect phase always uses ``self.connect_timeout``.
        """
        return httpx.Timeout(
            read_timeout if read_timeout is not None else self.timeout,
            connect=self.connect_timeout,
        )

    async def _request(
        self,
        method: str,
        endpoint: str,
        body: Optional[bytes] = None,
        timeout: Optional[httpx.Timeout] = None,
        **kwargs,
    ) -> httpx.Response:
        """Make a request to the Data Plane via Router.

        Args:
            method: HTTP method.
            endpoint: Path relative to ``self.base_url``.
            body: Optional raw request body; when present it is sent with a
                JSON content type unless the caller set one.
            timeout: Optional timeout; defaults to ``self._make_timeout()``.
            **kwargs: Forwarded to the HTTP session's ``request`` call.

        Returns:
            The unchecked response; callers call ``raise_for_status()``.
        """
        url = urljoin(self.base_url, endpoint)
        if timeout is None:
            timeout = self._make_timeout()

        # Copy so the caller's dict is never mutated by setdefault() below.
        extra_headers = dict(kwargs.pop("headers", None) or {})
        if body:
            extra_headers.setdefault("Content-Type", "application/json")

        self.logger.debug(f"{method} {url}")

        return await self._http_session.request(
            method=method,
            url=url,
            content=body,
            headers=extra_headers,
            timeout=timeout,
            **kwargs,
        )

    async def execute_command(
        self, command: Union[str, List[str]], timeout: Optional[float] = None
    ) -> str:
        """Execute a shell command.

        Args:
            command: The command to execute, as a string or list of arguments.
                A string is split with ``shlex.split``.
            timeout: Optional timeout for the command execution.

        Returns:
            The stdout output of the command.

        Raises:
            CommandExecutionError: If the command exits with a non-zero code.
        """
        timeout_value = timeout if timeout is not None else self.timeout
        is_numeric = isinstance(timeout_value, (int, float))
        timeout_str = f"{timeout_value}s" if is_numeric else str(timeout_value)

        cmd_list = shlex.split(command, posix=True) if isinstance(command, str) else command

        payload = {"command": cmd_list, "timeout": timeout_str}
        body = json.dumps(payload).encode("utf-8")

        # Add a buffer so httpx doesn't time out before PicoD returns the JSON response
        read_timeout = timeout_value + _TIMEOUT_BUFFER_SECONDS if is_numeric else timeout_value
        t = self._make_timeout(read_timeout)

        resp = await self._request("POST", "api/execute", body=body, timeout=t)
        resp.raise_for_status()
        result = resp.json()

        if result["exit_code"] != 0:
            raise CommandExecutionError(
                exit_code=result["exit_code"],
                stderr=result["stderr"],
                command=command,
            )

        return result["stdout"]

    async def run_code(
        self, language: str, code: str, timeout: Optional[float] = None
    ) -> str:
        """Run a code snippet (python or bash).

        The snippet is written to a uniquely named remote script file and then
        executed, which avoids shell quoting issues.

        Args:
            language: ``python``/``py``/``python3`` or ``bash``/``sh`` (case-insensitive).
            code: Source code to run.
            timeout: Optional execution timeout in seconds.

        Returns:
            The stdout of the script.

        Raises:
            ValueError: If the language is not supported.
        """
        lang = language.lower()
        if lang in ["python", "py", "python3"]:
            try:
                ast.parse(code)
            except SyntaxError:
                # The code may have arrived with literal "\n" sequences instead
                # of newlines; use the repaired text only if it parses.
                fixed_code = code.replace("\\n", "\n")
                try:
                    ast.parse(fixed_code)
                    self.logger.warning(
                        "Detected and fixed double-escaped newlines in Python code."
                    )
                    code = fixed_code
                except SyntaxError:
                    pass
            except Exception as e:
                self.logger.debug(f"AST parsing fallback error: {e}", exc_info=True)

            filename = f"script-{uuid.uuid4()}.py"
            await self.write_file(code, filename)
            cmd: List[str] = ["python3", filename]
        elif lang in ["bash", "sh"]:
            filename = f"script-{uuid.uuid4()}.sh"
            await self.write_file(code, filename)
            cmd = ["bash", filename]
        else:
            raise ValueError(f"Unsupported language: {language}")

        return await self.execute_command(cmd, timeout)

    async def write_file(self, content: str, remote_path: str) -> None:
        """Write text content to a file (sent base64-encoded, mode 0644)."""
        content_b64 = base64.b64encode(content.encode("utf-8")).decode("utf-8")
        payload = {"path": remote_path, "content": content_b64, "mode": "0644"}
        body = json.dumps(payload).encode("utf-8")

        resp = await self._request("POST", "api/files", body=body)
        resp.raise_for_status()

    async def upload_file(self, local_path: str, remote_path: str) -> None:
        """Upload a local file using multipart/form-data.

        The file is read in a worker thread so the event loop is not blocked
        by disk I/O; the whole file is held in memory during the upload.

        Raises:
            FileNotFoundError: If ``local_path`` does not exist.
        """
        if not os.path.exists(local_path):
            raise FileNotFoundError(f"Local file not found: {local_path}")

        url = urljoin(self.base_url, "api/files")
        self.logger.debug(f"Uploading file {local_path} to {remote_path}")

        file_bytes = await asyncio.to_thread(_read_bytes, local_path)
        resp = await self._http_session.post(
            url,
            files={"file": (os.path.basename(local_path), file_bytes)},
            data={"path": remote_path, "mode": "0644"},
            timeout=self._make_timeout(),
        )
        resp.raise_for_status()

    async def download_file(self, remote_path: str, local_path: str) -> None:
        """Download a remote file and write it to *local_path*.

        The parent directory of *local_path* is created when missing.
        """
        clean_path = remote_path.lstrip("/")
        async with self._http_session.stream(
            "GET",
            urljoin(self.base_url, f"api/files/{clean_path}"),
            timeout=self._make_timeout(),
        ) as resp:
            resp.raise_for_status()
            content = await resp.aread()

        # Directory creation and the file write are both blocking, so both run
        # in a worker thread to keep the event loop free.
        await asyncio.to_thread(_write_bytes, local_path, content)

    async def list_files(self, path: str = ".") -> Any:
        """List files in a directory.

        Returns the ``files`` entry of the JSON response, or an empty list
        when that key is absent.
        """
        resp = await self._request("GET", "api/files", params={"path": path})
        resp.raise_for_status()
        return resp.json().get("files", [])

    async def close(self) -> None:
        """Close the underlying HTTP session."""
        await self._http_session.aclose()

    async def __aenter__(self) -> "AsyncCodeInterpreterDataPlaneClient":
        """Return the client itself; nothing is opened on entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP session when leaving the context."""
        await self.close()
class AsyncControlPlaneClient:
    """Async client for AgentCube Control Plane (WorkloadManager).

    Handles creation and deletion of Code Interpreter sessions.
    """

    def __init__(
        self,
        workload_manager_url: Optional[str] = None,
        auth_token: Optional[str] = None,
        timeout: int = 120,
        connect_timeout: float = 5.0,
        connector_limit: int = 100,
        connector_limit_per_host: int = 10,
    ):
        """Initialize the async Control Plane client.

        Args:
            workload_manager_url: URL of the WorkloadManager service; falls
                back to the ``WORKLOAD_MANAGER_URL`` environment variable.
            auth_token: Kubernetes Service Account Token for authentication;
                when omitted, the token is read from the service-account
                token file.
            timeout: Default request timeout in seconds (default: 120).
            connect_timeout: Connection timeout in seconds (default: 5).
            connector_limit: Total simultaneous connections (default: 100).
            connector_limit_per_host: Max keepalive connections per host (default: 10).

        Raises:
            ValueError: If no WorkloadManager URL is available.
        """
        self.base_url = workload_manager_url or os.getenv("WORKLOAD_MANAGER_URL")
        if not self.base_url:
            raise ValueError(
                "Workload Manager URL must be provided via 'workload_manager_url' argument "
                "or 'WORKLOAD_MANAGER_URL' environment variable."
            )

        token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token"
        token = auth_token or read_token_from_file(token_path)

        self.timeout = httpx.Timeout(timeout, connect=connect_timeout)
        self.logger = get_logger(f"{__name__}.AsyncControlPlaneClient")

        headers: Dict[str, str] = {"Content-Type": "application/json"}
        if token:
            headers["Authorization"] = f"Bearer {token}"

        self.session = create_async_session(
            connector_limit=connector_limit,
            connector_limit_per_host=connector_limit_per_host,
        )
        self.session.headers.update(headers)

    def _url(self, path: str) -> str:
        """Join *path* onto the base URL with exactly one ``/`` between them."""
        # A base URL configured with a trailing slash would otherwise yield
        # "//v1/..." request URLs.
        return f"{self.base_url.rstrip('/')}/{path.lstrip('/')}"

    async def create_session(
        self,
        name: str = "my-interpreter",
        namespace: str = "default",
        metadata: Optional[Dict[str, Any]] = None,
        ttl: int = 3600,
    ) -> str:
        """Create a new Code Interpreter session.

        Args:
            name: Name of the CodeInterpreter template (CRD name).
            namespace: Kubernetes namespace.
            metadata: Optional metadata.
            ttl: Time to live (seconds).

        Returns:
            session_id (str): The ID of the created session.

        Raises:
            ValueError: If the response carries no usable ``sessionId``.
        """
        payload = {
            "name": name,
            "namespace": namespace,
            "ttl": ttl,
            "metadata": metadata or {},
        }

        url = self._url("v1/code-interpreter")
        self.logger.debug(f"Creating session at {url} with payload: {payload}")

        resp = await self.session.post(url, json=payload, timeout=self.timeout)
        resp.raise_for_status()
        data = resp.json()

        # A non-object JSON body (e.g. null) is treated like a missing ID so
        # callers always get the documented ValueError.
        session_id = data.get("sessionId") if isinstance(data, dict) else None
        if not session_id:
            self.logger.error("Response JSON missing 'sessionId' in create_session response.")
            self.logger.debug(f"Full response data: {data}")
            raise ValueError("Failed to create session: 'sessionId' missing from response")
        return session_id

    async def delete_session(self, session_id: str) -> bool:
        """Delete a Code Interpreter session.

        Args:
            session_id: The session ID to delete.

        Returns:
            True if deleted successfully (or didn't exist), False on failure.
        """
        url = self._url(f"v1/code-interpreter/sessions/{session_id}")
        self.logger.debug(f"Deleting session {session_id} at {url}")

        try:
            resp = await self.session.delete(url, timeout=self.timeout)
            if resp.status_code == 404:
                return True  # Already gone
            resp.raise_for_status()
            return True
        except httpx.HTTPError as e:
            # httpx.HTTPError is the base for both network errors (RequestError)
            # and HTTP status errors (HTTPStatusError). We treat all of them as
            # non-fatal so that callers can continue without a hard crash.
            self.logger.error(f"Failed to delete session {session_id}: {e}")
            return False

    async def close(self) -> None:
        """Close the underlying session and release connection pool resources."""
        await self.session.aclose()

    async def __aenter__(self) -> "AsyncControlPlaneClient":
        """Return the client itself; nothing is opened on entry."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP session when leaving the context."""
        await self.close()
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async HTTP session utilities for AgentCube SDK.""" + +import httpx + + +def create_async_session( + connector_limit: int = 10, + connector_limit_per_host: int = 10, +) -> httpx.AsyncClient: + """Create an httpx AsyncClient with connection pooling. + + Args: + connector_limit: Total number of simultaneous connections (default: 10). + connector_limit_per_host: Max idle keep-alive connections kept in the pool; httpx applies this pool-wide, not per host (default: 10). + + Returns: + A configured httpx.AsyncClient with connection limits. + """ + limits = httpx.Limits( + max_connections=connector_limit, + max_keepalive_connections=connector_limit_per_host, + ) + return httpx.AsyncClient(limits=limits) diff --git a/sdk-python/examples/async_agent_runtime_usage.py b/sdk-python/examples/async_agent_runtime_usage.py new file mode 100644 index 00000000..36280c48 --- /dev/null +++ b/sdk-python/examples/async_agent_runtime_usage.py @@ -0,0 +1,51 @@ +# Copyright The Volcano Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import asyncio + +from agentcube import AsyncAgentRuntimeClient + + +async def main(): + # first time: it will create a new pod + async with AsyncAgentRuntimeClient( + agent_name="my-agent", + router_url="http://localhost:18081", + namespace="default", + verbose=True, + ) as client_v1: + print(client_v1.session_id) + + result_v1 = await client_v1.invoke(payload={"prompt": "Hello World!"}) + print(result_v1) + + session_id = client_v1.session_id + + # second time: reuse the pod created above + async with AsyncAgentRuntimeClient( + agent_name="my-agent", + router_url="http://localhost:18081", + namespace="default", + session_id=session_id, + verbose=True, + ) as client_v2: + # same session_id as the first time + print(client_v2.session_id) + + result_v2 = await client_v2.invoke(payload={"prompt": "Hello World!"}) + print(result_v2) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk-python/examples/async_basic_usage.py b/sdk-python/examples/async_basic_usage.py new file mode 100644 index 00000000..1e3b5737 --- /dev/null +++ b/sdk-python/examples/async_basic_usage.py @@ -0,0 +1,96 @@ +# Copyright The Volcano Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +AgentCube SDK Async Basic Usage Example + +Demonstrates: +1. Async basic command execution and code running +2. Async file operations +3. 
Session reuse for file-system state persistence (useful for AI workflows) +""" + +import asyncio + +from agentcube import AsyncCodeInterpreterClient + + +async def basic_operations(): + """Demonstrate basic async SDK operations with context manager.""" + print("=== Async Basic Operations ===\n") + + async with AsyncCodeInterpreterClient(verbose=True) as client: + print(f"Session ID: {client.session_id}") + + # 1. Shell commands + print("\n--- Shell Command: whoami ---") + output = await client.execute_command("whoami") + print(f"Result: {output.strip()}") + + # 2. Python code execution + print("\n--- Python Code ---") + code = """ +import math +print(f"Pi is approximately {math.pi:.6f}") +""" + output = await client.run_code("python", code) + print(f"Result: {output.strip()}") + + # 3. File operations + print("\n--- File Operations ---") + await client.write_file("Hello from AgentCube!", "hello.txt") + files = await client.list_files(".") + print(f"Files: {[f['name'] for f in files]}") + + # Session automatically deleted on exit + print("\nSession deleted.") + + +async def session_reuse_example(): + """ + Demonstrate async session reuse for AI workflows. + + File system state persists across sessions; Python variables do not. 
+ """ + print("\n=== Async Session Reuse (File State Persistence) ===\n") + + # Step 1: Create session and write a file + print("Step 1: Create session, write value.txt = 42") + client1 = await AsyncCodeInterpreterClient.create(verbose=True) + await client1.write_file("42", "value.txt") + session_id = client1.session_id + print(f"Session ID saved: {session_id}") + # Don't call stop() - let session persist + + # Step 2: Reuse session - file system state should still exist + print("\nStep 2: Reuse session, read value.txt") + client2 = await AsyncCodeInterpreterClient.create( + session_id=session_id, verbose=True + ) + result = await client2.run_code("python", "print(open('value.txt').read())") + print(f"Result: {result.strip()}") # Should print "42" + + # Step 3: Cleanup + print("\nStep 3: Delete session") + await client2.stop() + print("Session deleted.") + + +async def main(): + await basic_operations() + await session_reuse_example() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk-python/pyproject.toml b/sdk-python/pyproject.toml index ee6ffbaa..9eb5dba9 100644 --- a/sdk-python/pyproject.toml +++ b/sdk-python/pyproject.toml @@ -13,7 +13,8 @@ requires-python = ">=3.10" dependencies = [ "requests", "PyJWT>=2.0.0", - "cryptography" + "cryptography", + "httpx>=0.27.0", ] [tool.setuptools.packages.find] diff --git a/sdk-python/requirements.txt b/sdk-python/requirements.txt index 79123ef9..a114a4b7 100644 --- a/sdk-python/requirements.txt +++ b/sdk-python/requirements.txt @@ -1,3 +1,4 @@ requests PyJWT>=2.0.0 cryptography +httpx>=0.27.0 diff --git a/sdk-python/tests/test_async_agent_runtime.py b/sdk-python/tests/test_async_agent_runtime.py new file mode 100644 index 00000000..2a65ca41 --- /dev/null +++ b/sdk-python/tests/test_async_agent_runtime.py @@ -0,0 +1,170 @@ +# Copyright The Volcano Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +os.environ.setdefault("ROUTER_URL", "http://mock-router:8080") + +from agentcube.async_agent_runtime import AsyncAgentRuntimeClient + + +class TestAsyncAgentRuntimeClientSessionBootstrap(unittest.IsolatedAsyncioTestCase): + @patch("agentcube.async_agent_runtime.AsyncAgentRuntimeDataPlaneClient") + async def test_start_bootstraps_session_id_when_missing(self, mock_dp_class): + mock_dp = MagicMock() + mock_dp.bootstrap_session_id = AsyncMock(return_value="sess-123") + mock_dp.logger = MagicMock() + mock_dp_class.return_value = mock_dp + + client = AsyncAgentRuntimeClient(agent_name="agent-a", router_url="http://t:1") + await client.start() + + self.assertEqual(client.session_id, "sess-123") + mock_dp.bootstrap_session_id.assert_awaited_once_with() + + @patch("agentcube.async_agent_runtime.AsyncAgentRuntimeDataPlaneClient") + async def test_start_reuses_session_id_when_provided(self, mock_dp_class): + mock_dp = MagicMock() + mock_dp.logger = MagicMock() + mock_dp_class.return_value = mock_dp + + client = AsyncAgentRuntimeClient( + agent_name="agent-a", + router_url="http://t:1", + session_id="existing-456", + ) + await client.start() + + self.assertEqual(client.session_id, "existing-456") + mock_dp.bootstrap_session_id.assert_not_called() + + +class TestAsyncAgentRuntimeClientContextManager(unittest.IsolatedAsyncioTestCase): + @patch("agentcube.async_agent_runtime.AsyncAgentRuntimeDataPlaneClient") + async def test_context_manager_starts_and_closes(self, mock_dp_class): + 
mock_dp = MagicMock() + mock_dp.bootstrap_session_id = AsyncMock(return_value="ctx-sess") + mock_dp.close = AsyncMock() + mock_dp.logger = MagicMock() + mock_dp_class.return_value = mock_dp + + async with AsyncAgentRuntimeClient( + agent_name="agent-a", router_url="http://t:1" + ) as client: + self.assertEqual(client.session_id, "ctx-sess") + + mock_dp.close.assert_awaited_once() + + +class TestAsyncAgentRuntimeClientInvoke(unittest.IsolatedAsyncioTestCase): + @patch("agentcube.async_agent_runtime.AsyncAgentRuntimeDataPlaneClient") + async def test_invoke_returns_json(self, mock_dp_class): + mock_dp = MagicMock() + mock_dp.bootstrap_session_id = AsyncMock(return_value="sess-789") + mock_dp.logger = MagicMock() + + # httpx responses are fully loaded — json() and text are sync + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json = MagicMock(return_value={"ok": True}) + mock_dp.invoke = AsyncMock(return_value=mock_resp) + + mock_dp_class.return_value = mock_dp + + client = AsyncAgentRuntimeClient(agent_name="agent-a", router_url="http://t:1") + await client.start() + out = await client.invoke({"input": "hi"}) + + self.assertEqual(out, {"ok": True}) + mock_dp.invoke.assert_awaited_once() + call_kwargs = mock_dp.invoke.call_args.kwargs + self.assertEqual(call_kwargs["session_id"], "sess-789") + self.assertEqual(call_kwargs["payload"], {"input": "hi"}) + + @patch("agentcube.async_agent_runtime.AsyncAgentRuntimeDataPlaneClient") + async def test_invoke_falls_back_to_text(self, mock_dp_class): + mock_dp = MagicMock() + mock_dp.bootstrap_session_id = AsyncMock(return_value="sess-999") + mock_dp.logger = MagicMock() + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json = MagicMock(side_effect=ValueError("not json")) + mock_resp.text = "plain" + mock_dp.invoke = AsyncMock(return_value=mock_resp) + + mock_dp_class.return_value = mock_dp + + client = AsyncAgentRuntimeClient(agent_name="agent-a", 
router_url="http://t:1") + await client.start() + out = await client.invoke({"input": "hi"}) + + self.assertEqual(out, "plain") + + +class TestAsyncAgentRuntimeDataPlaneClient(unittest.IsolatedAsyncioTestCase): + @patch("agentcube.clients.async_agent_runtime_data_plane.create_async_session") + async def test_bootstrap_session_id_extracts_header(self, mock_create): + from agentcube.clients.async_agent_runtime_data_plane import ( + AsyncAgentRuntimeDataPlaneClient, + ) + + # httpx responses are fully loaded; no async context manager needed + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.headers = {"x-agentcube-session-id": "abc"} + + mock_session = MagicMock() + mock_session.get = AsyncMock(return_value=mock_resp) + mock_create.return_value = mock_session + + client = AsyncAgentRuntimeDataPlaneClient( + router_url="http://router", + namespace="default", + agent_name="agent-a", + ) + result = await client.bootstrap_session_id() + self.assertEqual(result, "abc") + + @patch("agentcube.clients.async_agent_runtime_data_plane.create_async_session") + async def test_invoke_sends_session_header(self, mock_create): + from agentcube.clients.async_agent_runtime_data_plane import ( + AsyncAgentRuntimeDataPlaneClient, + ) + + mock_resp = MagicMock() + mock_session = MagicMock() + mock_session.post = AsyncMock(return_value=mock_resp) + mock_create.return_value = mock_session + + client = AsyncAgentRuntimeDataPlaneClient( + router_url="http://router", + namespace="default", + agent_name="agent-a", + connect_timeout=1.0, + timeout=2, + ) + resp = await client.invoke(session_id="sid", payload={"p": 1}) + self.assertIs(resp, mock_resp) + + mock_session.post.assert_awaited_once() + call_kwargs = mock_session.post.call_args.kwargs + self.assertEqual(call_kwargs["headers"]["x-agentcube-session-id"], "sid") + self.assertEqual(call_kwargs["json"], {"p": 1}) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/sdk-python/tests/test_async_code_interpreter.py b/sdk-python/tests/test_async_code_interpreter.py new file mode 100644 index 00000000..c9190caf --- /dev/null +++ b/sdk-python/tests/test_async_code_interpreter.py @@ -0,0 +1,148 @@ +# Copyright The Volcano Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for AsyncCodeInterpreterClient session management. + +Tests cover: +- Session creation +- Session reuse +- Context manager behavior +- Error handling / resource cleanup +""" + +import os +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +# Set required env var before import +os.environ.setdefault("ROUTER_URL", "http://mock-router:8080") + +from agentcube.async_code_interpreter import AsyncCodeInterpreterClient + + +def _make_async_cp(session_id="new-session-123"): + """Return a mock AsyncControlPlaneClient.""" + mock_cp = MagicMock() + mock_cp.create_session = AsyncMock(return_value=session_id) + mock_cp.delete_session = AsyncMock(return_value=True) + mock_cp.close = AsyncMock() + mock_cp.logger = MagicMock() + return mock_cp + + +def _make_async_dp(): + """Return a mock AsyncCodeInterpreterDataPlaneClient.""" + mock_dp = MagicMock() + mock_dp.close = AsyncMock() + mock_dp.logger = MagicMock() + return mock_dp + + +class TestAsyncCodeInterpreterClientCreate(unittest.IsolatedAsyncioTestCase): + """Test client creation via the async factory classmethod.""" + + 
@patch("agentcube.async_code_interpreter.AsyncCodeInterpreterDataPlaneClient") + @patch("agentcube.async_code_interpreter.AsyncControlPlaneClient") + async def test_create_creates_session(self, mock_cp_class, mock_dp_class): + """AsyncCodeInterpreterClient.create() should create a session.""" + mock_cp = _make_async_cp("new-session-123") + mock_cp_class.return_value = mock_cp + + client = await AsyncCodeInterpreterClient.create(router_url="http://test:8080") + + self.assertEqual(client.session_id, "new-session-123") + mock_cp.create_session.assert_awaited_once() + mock_dp_class.assert_called_once() + + @patch("agentcube.async_code_interpreter.AsyncCodeInterpreterDataPlaneClient") + @patch("agentcube.async_code_interpreter.AsyncControlPlaneClient") + async def test_create_with_session_id_reuses_session(self, mock_cp_class, mock_dp_class): + """Providing session_id should reuse existing session.""" + mock_cp = _make_async_cp() + mock_cp_class.return_value = mock_cp + + client = await AsyncCodeInterpreterClient.create( + router_url="http://test:8080", + session_id="existing-session-123", + ) + + self.assertEqual(client.session_id, "existing-session-123") + mock_cp.create_session.assert_not_awaited() + mock_dp_class.assert_called_once() + + +class TestAsyncCodeInterpreterContextManager(unittest.IsolatedAsyncioTestCase): + """Test async context manager behavior.""" + + @patch("agentcube.async_code_interpreter.AsyncCodeInterpreterDataPlaneClient") + @patch("agentcube.async_code_interpreter.AsyncControlPlaneClient") + async def test_context_manager_calls_stop(self, mock_cp_class, mock_dp_class): + """Async context manager should call stop() on exit.""" + mock_cp = _make_async_cp("ctx-session-123") + mock_cp_class.return_value = mock_cp + + mock_dp = _make_async_dp() + mock_dp_class.return_value = mock_dp + + async with AsyncCodeInterpreterClient(router_url="http://test:8080") as _client: + pass + + mock_cp.delete_session.assert_awaited_once_with("ctx-session-123") + 
mock_dp.close.assert_awaited_once() + mock_cp.close.assert_awaited_once() + + +class TestAsyncCodeInterpreterSessionReuse(unittest.IsolatedAsyncioTestCase): + """Test session reuse across client instances.""" + + @patch("agentcube.async_code_interpreter.AsyncCodeInterpreterDataPlaneClient") + @patch("agentcube.async_code_interpreter.AsyncControlPlaneClient") + async def test_reuse_session_no_new_creation(self, mock_cp_class, mock_dp_class): + """Reusing session_id should not create a new session.""" + mock_cp = _make_async_cp() + mock_cp_class.return_value = mock_cp + + _client = await AsyncCodeInterpreterClient.create( + router_url="http://test:8080", + session_id="reused-session-789", + ) + + mock_cp.create_session.assert_not_awaited() + mock_dp_class.assert_called_once() + call_kwargs = mock_dp_class.call_args[1] + self.assertEqual(call_kwargs["session_id"], "reused-session-789") + + +class TestAsyncCodeInterpreterResourceLeakPrevention(unittest.IsolatedAsyncioTestCase): + """Test that resources are cleaned up on failure.""" + + @patch("agentcube.async_code_interpreter.AsyncCodeInterpreterDataPlaneClient") + @patch("agentcube.async_code_interpreter.AsyncControlPlaneClient") + async def test_cleanup_on_dp_init_failure(self, mock_cp_class, mock_dp_class): + """Session should be deleted if data plane client init fails.""" + mock_cp = _make_async_cp("leaked-session-999") + mock_cp_class.return_value = mock_cp + + mock_dp_class.side_effect = Exception("Connection failed") + + with self.assertRaises(Exception) as ctx: + await AsyncCodeInterpreterClient.create(router_url="http://test:8080") + + self.assertIn("Connection failed", str(ctx.exception)) + mock_cp.delete_session.assert_awaited_once_with("leaked-session-999") + + +if __name__ == "__main__": + unittest.main()