diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md index 198376bcbb..18bbd8dd17 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -35,4 +35,93 @@ Container naming behavior: - Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`) - `session_id` is used as the Cosmos partition key for reads/writes -See `samples/cosmos_history_provider.py` for a runnable package-local example. +See `samples/history_provider/cosmos_history_basic.py` for a runnable package-local example. + +## Cosmos DB Workflow Checkpoint Storage + +`CosmosCheckpointStorage` implements the `CheckpointStorage` protocol, enabling +durable workflow checkpointing backed by Azure Cosmos DB NoSQL. Workflows can be +paused and resumed across process restarts by persisting checkpoint state in Cosmos DB. + +### Basic Usage + +#### Managed Identity / RBAC (recommended for production) + +```python +from azure.identity.aio import DefaultAzureCredential +from agent_framework import WorkflowBuilder +from agent_framework_azure_cosmos import CosmosCheckpointStorage + +checkpoint_storage = CosmosCheckpointStorage( + endpoint="https://<your-account>.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="workflow-checkpoints", +) +``` + +#### Account Key + +```python +from agent_framework_azure_cosmos import CosmosCheckpointStorage + +checkpoint_storage = CosmosCheckpointStorage( + endpoint="https://<your-account>.documents.azure.com:443/", + credential="<your-account-key>", + database_name="agent-framework", + container_name="workflow-checkpoints", +) +``` + +#### Then use with a workflow + +```python +from agent_framework import WorkflowBuilder + +# Build a workflow with checkpointing enabled +workflow = WorkflowBuilder( + start_executor=start, + checkpoint_storage=checkpoint_storage, +).build() + +# Run the workflow — checkpoints are automatically saved after each 
superstep +result = await workflow.run(message="input data") + +# Resume from a checkpoint +latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) +if latest: + resumed = await workflow.run(checkpoint_id=latest.checkpoint_id) +``` + +### Authentication Options + +`CosmosCheckpointStorage` supports the same authentication modes as `CosmosHistoryProvider`: + +- **Managed identity / RBAC** (recommended): Pass `DefaultAzureCredential()`, + `ManagedIdentityCredential()`, or any Azure `TokenCredential` +- **Account key**: Pass a key string via `credential` parameter +- **Environment variables**: Set `AZURE_COSMOS_ENDPOINT`, `AZURE_COSMOS_DATABASE_NAME`, + `AZURE_COSMOS_CONTAINER_NAME`, and `AZURE_COSMOS_KEY` (key not required when using + Azure credentials) +- **Pre-created client**: Pass an existing `CosmosClient` or `ContainerProxy` + +### Database and Container Setup + +The database and container are created automatically on first use (via +`create_database_if_not_exists` and `create_container_if_not_exists`). The container +uses `/workflow_name` as the partition key. You can also pre-create them in the Azure +portal with this partition key configuration. + +### Environment Variables + +| Variable | Description | +|---|---| +| `AZURE_COSMOS_ENDPOINT` | Cosmos DB account endpoint | +| `AZURE_COSMOS_DATABASE_NAME` | Database name | +| `AZURE_COSMOS_CONTAINER_NAME` | Container name | +| `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) | + +See `samples/checkpoint_storage/cosmos_checkpoint_workflow.py` for a standalone example, +`samples/checkpoint_storage/cosmos_checkpoint_foundry.py` for an end-to-end example +with Azure AI Foundry agents, or `samples/cosmos_e2e_foundry.py` for both +history and checkpointing together. 
diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py index 5bcfb3928b..66373b0f1d 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py @@ -2,6 +2,7 @@ import importlib.metadata +from ._checkpoint_storage import CosmosCheckpointStorage from ._history_provider import CosmosHistoryProvider try: @@ -10,6 +11,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ + "CosmosCheckpointStorage", "CosmosHistoryProvider", "__version__", ] diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py new file mode 100644 index 0000000000..08db5d51b9 --- /dev/null +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -0,0 +1,426 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Azure Cosmos DB checkpoint storage for workflow checkpointing.""" + +from __future__ import annotations + +import logging +from typing import Any, TypedDict + +from agent_framework import AGENT_FRAMEWORK_USER_AGENT +from agent_framework._settings import SecretString, load_settings +from agent_framework._workflows._checkpoint import CheckpointID, WorkflowCheckpoint +from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from agent_framework.exceptions import WorkflowCheckpointException +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.cosmos import PartitionKey +from azure.cosmos.aio import ContainerProxy, CosmosClient +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +AzureCredentialTypes = TokenCredential | AsyncTokenCredential + +logger = logging.getLogger(__name__) + + +class AzureCosmosCheckpointSettings(TypedDict, total=False): + """Settings for CosmosCheckpointStorage resolved from args and environment.""" + + endpoint: str | None + database_name: str | None + container_name: str | None + key: SecretString | None + + +class CosmosCheckpointStorage: + """Azure Cosmos DB-backed checkpoint storage for workflow checkpointing. + + Implements the ``CheckpointStorage`` protocol using Azure Cosmos DB NoSQL + as the persistent backend. Checkpoints are stored as JSON documents with + ``workflow_name`` as the partition key, enabling efficient per-workflow queries. + + This storage uses the same hybrid JSON + pickle encoding as + ``FileCheckpointStorage``, allowing full Python object fidelity for + complex workflow state while keeping the document structure human-readable. + + SECURITY WARNING: Checkpoints use pickle for data serialization. Only load + checkpoints from trusted sources. Loading a malicious checkpoint can execute + arbitrary code. 
+ + The database and container are created automatically on first use + if they do not already exist. The container uses partition key + ``/workflow_name``. + + Example using managed identity / RBAC:: + + from azure.identity.aio import DefaultAzureCredential + from agent_framework_azure_cosmos import CosmosCheckpointStorage + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-db", + container_name="checkpoints", + ) + + Example using account key:: + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential="my-account-key", + database_name="agent-db", + container_name="checkpoints", + ) + + Then use with a workflow builder:: + + workflow = WorkflowBuilder( + start_executor=start, + checkpoint_storage=storage, + ).build() + """ + + def __init__( + self, + *, + endpoint: str | None = None, + database_name: str | None = None, + container_name: str | None = None, + credential: str | AzureCredentialTypes | None = None, + cosmos_client: CosmosClient | None = None, + container_client: ContainerProxy | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize the Azure Cosmos DB checkpoint storage. + + Supports multiple authentication modes: + + - **Container client** (``container_client``): Use a pre-created + Cosmos async container proxy. No client lifecycle is managed. + - **Cosmos client** (``cosmos_client``): Use a pre-created Cosmos + async client. The caller is responsible for closing it. + - **Endpoint + credential**: Create a new Cosmos client. The storage + owns the client and closes it on ``close()``. + - **Environment variables**: Falls back to ``AZURE_COSMOS_ENDPOINT``, + ``AZURE_COSMOS_DATABASE_NAME``, ``AZURE_COSMOS_CONTAINER_NAME``, + and ``AZURE_COSMOS_KEY``. + + Args: + endpoint: Cosmos DB account endpoint. 
+ Can be set via ``AZURE_COSMOS_ENDPOINT``. + database_name: Cosmos DB database name. + Can be set via ``AZURE_COSMOS_DATABASE_NAME``. + container_name: Cosmos DB container name. + Can be set via ``AZURE_COSMOS_CONTAINER_NAME``. + credential: Credential to authenticate with Cosmos DB. + For **managed identity / RBAC**, pass an Azure credential object + such as ``DefaultAzureCredential()`` or + ``ManagedIdentityCredential()``. + For **key-based auth**, pass the account key as a string, + or set ``AZURE_COSMOS_KEY`` in the environment. + cosmos_client: Pre-created Cosmos async client. + container_client: Pre-created Cosmos container client. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + """ + self._cosmos_client: CosmosClient | None = cosmos_client + self._container_proxy: ContainerProxy | None = container_client + self._owns_client = False + + if self._container_proxy is not None: + self.database_name: str = database_name or "" + self.container_name: str = container_name or "" + return + + required_fields: list[str] = ["database_name", "container_name"] + if cosmos_client is None: + required_fields.append("endpoint") + if credential is None: + required_fields.append("key") + + settings = load_settings( + AzureCosmosCheckpointSettings, + env_prefix="AZURE_COSMOS_", + required_fields=required_fields, + endpoint=endpoint, + database_name=database_name, + container_name=container_name, + key=credential if isinstance(credential, str) else None, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + self.database_name = settings["database_name"] # type: ignore[assignment] + self.container_name = settings["container_name"] # type: ignore[assignment] + + if self._cosmos_client is None: + self._cosmos_client = CosmosClient( + url=settings["endpoint"], # type: ignore[arg-type] + credential=credential or settings["key"].get_secret_value(), # type: ignore[arg-type,union-attr] + 
user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT, + ) + self._owns_client = True + + async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID: + """Save a checkpoint to Cosmos DB and return its ID. + + The checkpoint is encoded to a JSON-compatible form (using pickle for + non-JSON-native values) and stored as a Cosmos DB document with the + ``workflow_name`` as the partition key. + + The document ``id`` is a composite of ``workflow_name`` and + ``checkpoint_id`` to ensure global uniqueness across partitions. + + Args: + checkpoint: The WorkflowCheckpoint object to save. + + Returns: + The unique ID of the saved checkpoint. + """ + await self._ensure_container_proxy() + + checkpoint_dict = checkpoint.to_dict() + encoded = encode_checkpoint_value(checkpoint_dict) + + document: dict[str, Any] = { + "id": self._make_document_id(checkpoint.workflow_name, checkpoint.checkpoint_id), + "workflow_name": checkpoint.workflow_name, + **encoded, + } + + await self._container_proxy.upsert_item(body=document) # type: ignore[union-attr] + logger.info("Saved checkpoint %s to Cosmos DB", checkpoint.checkpoint_id) + return checkpoint.checkpoint_id + + async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint: + """Load a checkpoint from Cosmos DB by ID. + + Args: + checkpoint_id: The unique ID of the checkpoint to load. + + Returns: + The WorkflowCheckpoint object corresponding to the given ID. + + Raises: + WorkflowCheckpointException: If no checkpoint with the given ID exists, + or if multiple checkpoints share the same ID across workflows. 
+ """ + await self._ensure_container_proxy() + + query = "SELECT * FROM c WHERE c.checkpoint_id = @checkpoint_id" + parameters: list[dict[str, object]] = [ + {"name": "@checkpoint_id", "value": checkpoint_id}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + ) + + results: list[dict[str, Any]] = [] + async for item in items: + results.append(item) + + if not results: + raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}") + + if len(results) > 1: + workflow_names = [r.get("workflow_name", "unknown") for r in results] + raise WorkflowCheckpointException( + f"Multiple checkpoints found with ID {checkpoint_id} across workflows: " + f"{workflow_names}. Use list_checkpoints(workflow_name=...) to query " + f"by workflow instead." + ) + + return self._document_to_checkpoint(results[0]) + + async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: + """List checkpoint objects for a given workflow name. + + Args: + workflow_name: The name of the workflow to list checkpoints for. + + Returns: + A list of WorkflowCheckpoint objects for the specified workflow name. + """ + await self._ensure_container_proxy() + + query = "SELECT * FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp ASC" + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + checkpoints: list[WorkflowCheckpoint] = [] + async for item in items: + try: + checkpoints.append(self._document_to_checkpoint(item)) + except Exception as e: + logger.warning("Failed to decode checkpoint document: %s", e) + return checkpoints + + async def delete(self, checkpoint_id: CheckpointID) -> bool: + """Delete a checkpoint from Cosmos DB by ID. 
+ + Args: + checkpoint_id: The unique ID of the checkpoint to delete. + + Returns: + True if the checkpoint was successfully deleted, False if not found. + """ + await self._ensure_container_proxy() + + query = "SELECT c.id, c.workflow_name FROM c WHERE c.checkpoint_id = @checkpoint_id" + parameters: list[dict[str, object]] = [ + {"name": "@checkpoint_id", "value": checkpoint_id}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + ) + + async for item in items: + try: + await self._container_proxy.delete_item( # type: ignore[union-attr] + item=item["id"], + partition_key=item["workflow_name"], + ) + logger.info("Deleted checkpoint %s from Cosmos DB", checkpoint_id) + return True + except CosmosResourceNotFoundError: + return False + + return False + + async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: + """Get the latest checkpoint for a given workflow name. + + Args: + workflow_name: The name of the workflow to get the latest checkpoint for. + + Returns: + The latest WorkflowCheckpoint, or None if no checkpoints exist. + """ + await self._ensure_container_proxy() + + query = ( + "SELECT * FROM c WHERE c.workflow_name = @workflow_name " + "ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1" + ) + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + async for item in items: + checkpoint = self._document_to_checkpoint(item) + logger.debug( + "Latest checkpoint for workflow %s is %s", + workflow_name, + checkpoint.checkpoint_id, + ) + return checkpoint + + return None + + async def list_checkpoint_ids(self, *, workflow_name: str) -> list[CheckpointID]: + """List checkpoint IDs for a given workflow name. + + Args: + workflow_name: The name of the workflow to list checkpoint IDs for. 
+ + Returns: + A list of checkpoint IDs for the specified workflow name. + """ + await self._ensure_container_proxy() + + query = ( + "SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name " + "ORDER BY c.timestamp ASC" + ) + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + checkpoint_ids: list[CheckpointID] = [] + async for item in items: + cid = item.get("checkpoint_id") + if isinstance(cid, str): + checkpoint_ids.append(cid) + return checkpoint_ids + + async def close(self) -> None: + """Close the underlying Cosmos client when this storage owns it.""" + if self._owns_client and self._cosmos_client is not None: + await self._cosmos_client.close() + + async def __aenter__(self) -> CosmosCheckpointStorage: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + try: + await self.close() + except Exception: + if exc_type is None: + raise + + async def _ensure_container_proxy(self) -> None: + """Get or create the Cosmos DB database and container for storing checkpoints.""" + if self._container_proxy is not None: + return + if self._cosmos_client is None: + raise RuntimeError("Cosmos client is not initialized.") + + database = await self._cosmos_client.create_database_if_not_exists(id=self.database_name) + self._container_proxy = await database.create_container_if_not_exists( + id=self.container_name, + partition_key=PartitionKey(path="/workflow_name"), + ) + + @staticmethod + def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: + """Convert a Cosmos DB document back to a WorkflowCheckpoint. 
+ + Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``, + ``_attachments``, ``_ts``) before decoding. + """ + # Remove Cosmos DB system properties and the composite 'id' field + # (checkpoints use 'checkpoint_id', not 'id') + cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"} + cleaned = {k: v for k, v in document.items() if k not in cosmos_keys} + + decoded = decode_checkpoint_value(cleaned) + return WorkflowCheckpoint.from_dict(decoded) + + @staticmethod + def _make_document_id(workflow_name: str, checkpoint_id: str) -> str: + """Create a composite Cosmos DB document ID. + + Combines ``workflow_name`` and ``checkpoint_id`` to ensure global + uniqueness across partitions. + """ + return f"{workflow_name}_{checkpoint_id}" diff --git a/python/packages/azure-cosmos/pyproject.toml b/python/packages/azure-cosmos/pyproject.toml index cbb8188de0..2e129236fe 100644 --- a/python/packages/azure-cosmos/pyproject.toml +++ b/python/packages/azure-cosmos/pyproject.toml @@ -84,6 +84,7 @@ exclude_dirs = ["tests"] [tool.poe] executor.type = "uv" include = "../../shared_tasks.toml" + [tool.poe.tasks.mypy] help = "Run MyPy for this package." cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_azure_cosmos" @@ -94,7 +95,7 @@ cmd = "pytest -m \"not integration\" --cov=agent_framework_azure_cosmos --cov-re [tool.poe.tasks.integration-tests] help = "Run the package integration test suite." 
-cmd = "pytest tests/test_cosmos_history_provider.py -m integration" +cmd = "pytest tests/test_cosmos_history_provider.py tests/test_cosmos_checkpoint_storage.py -m integration" [build-system] requires = ["flit-core >= 3.11,<4.0"] diff --git a/python/packages/azure-cosmos/samples/.env.template b/python/packages/azure-cosmos/samples/.env.template new file mode 100644 index 0000000000..c200634623 --- /dev/null +++ b/python/packages/azure-cosmos/samples/.env.template @@ -0,0 +1,6 @@ +AZURE_AI_PROJECT_ENDPOINT= +AZURE_AI_MODEL_DEPLOYMENT_NAME= +AZURE_COSMOS_ENDPOINT= +AZURE_COSMOS_KEY= +AZURE_COSMOS_DATABASE_NAME= +AZURE_COSMOS_CONTAINER_NAME= diff --git a/python/packages/azure-cosmos/samples/README.md b/python/packages/azure-cosmos/samples/README.md index 082a9c2cfe..4c448f159f 100644 --- a/python/packages/azure-cosmos/samples/README.md +++ b/python/packages/azure-cosmos/samples/README.md @@ -2,19 +2,55 @@ This folder contains samples for `agent-framework-azure-cosmos`. +## History Provider Samples + +Demonstrate conversation persistence using `CosmosHistoryProvider`. + +| File | Description | +| --- | --- | +| [`history_provider/cosmos_history_basic.py`](history_provider/cosmos_history_basic.py) | Basic multi-turn conversation using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient`, provider-configured container name, and `session_id` partitioning. | +| [`history_provider/cosmos_history_conversation_persistence.py`](history_provider/cosmos_history_conversation_persistence.py) | Persist and resume conversations across application restarts — serialize session state, create new provider/agent instances, and continue from Cosmos DB history. | +| [`history_provider/cosmos_history_messages.py`](history_provider/cosmos_history_messages.py) | Direct message history operations — retrieve stored messages as a transcript, clear session history, and verify data deletion. 
| +| [`history_provider/cosmos_history_sessions.py`](history_provider/cosmos_history_sessions.py) | Multi-session and multi-tenant management — per-tenant session isolation, `list_sessions()` to enumerate, switch between sessions, and resume specific conversations. | + +## Checkpoint Storage Samples + +Demonstrate workflow pause/resume using `CosmosCheckpointStorage`. + | File | Description | | --- | --- | -| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. | +| [`checkpoint_storage/cosmos_checkpoint_workflow.py`](checkpoint_storage/cosmos_checkpoint_workflow.py) | Workflow checkpoint storage with Cosmos DB — pause and resume workflows across restarts using `CosmosCheckpointStorage`, with support for key-based and managed identity auth. | +| [`checkpoint_storage/cosmos_checkpoint_foundry.py`](checkpoint_storage/cosmos_checkpoint_foundry.py) | End-to-end Azure AI Foundry + Cosmos DB checkpointing — multi-agent workflow using `AzureOpenAIResponsesClient` with `CosmosCheckpointStorage` for durable pause/resume. | + +## Combined Sample + +| File | Description | +| --- | --- | +| [`cosmos_e2e_foundry.py`](cosmos_e2e_foundry.py) | Both `CosmosHistoryProvider` and `CosmosCheckpointStorage` in a single Azure AI Foundry agent app — the recommended production pattern for fully durable agent workflows. 
| ## Prerequisites - `AZURE_COSMOS_ENDPOINT` - `AZURE_COSMOS_DATABASE_NAME` -- `AZURE_COSMOS_CONTAINER_NAME` - `AZURE_COSMOS_KEY` (or equivalent credential flow) +For Foundry samples, also set: +- `AZURE_AI_PROJECT_ENDPOINT` +- `AZURE_AI_MODEL_DEPLOYMENT_NAME` + ## Run ```bash -uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py +# History provider samples +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_basic.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_conversation_persistence.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_messages.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_sessions.py + +# Checkpoint storage samples +uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_workflow.py +uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_foundry.py + +# Combined sample +uv run --directory packages/azure-cosmos python samples/cosmos_e2e_foundry.py ``` diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py b/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py new file mode 100644 index 0000000000..7c19d114a0 --- /dev/null +++ b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +"""Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry. 
+ +Purpose: +This sample demonstrates how to use CosmosCheckpointStorage with agents built +on Azure AI Foundry (via AzureOpenAIResponsesClient). It shows a multi-agent +workflow where checkpoint state is persisted to Cosmos DB, enabling durable +pause-and-resume across process restarts. + +What you learn: +- How to wire CosmosCheckpointStorage with AzureOpenAIResponsesClient agents +- How to combine session history with workflow checkpointing +- How to resume a workflow-as-agent from a Cosmos DB checkpoint + +Key concepts: +- AgentSession: Maintains conversation history across agent invocations +- CosmosCheckpointStorage: Persists workflow execution state in Cosmos DB +- These are complementary: sessions track conversation, checkpoints track workflow state + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint + AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name + AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +from typing import Any + +from agent_framework.azure import AzureOpenAIResponsesClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosCheckpointStorage + +load_dotenv() + + +async def main() -> None: + """Run the Azure AI Foundry + Cosmos DB checkpointing sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT",) + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not project_endpoint or not 
deployment_name: + print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") + return + + if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: + print( + "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " + "and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # Use a single AzureCliCredential for both Cosmos and Foundry, + # properly closed via async context manager. + async with AzureCliCredential() as azure_credential: + cosmos_credential: Any = cosmos_key if cosmos_key else azure_credential + + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + # Create Azure AI Foundry agents + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=azure_credential, + ) + + assistant = client.as_agent( + name="assistant", + instructions="You are a helpful assistant. Keep responses brief.", + ) + + reviewer = client.as_agent( + name="reviewer", + instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.", + ) + + # Build a sequential workflow and wrap it as an agent + workflow = SequentialBuilder(participants=[assistant, reviewer]).build() + agent = workflow.as_agent(name="FoundryCheckpointedAgent") + + # --- First run: execute with Cosmos DB checkpointing --- + print("=== First Run ===\n") + + session = agent.create_session() + query = "What are the benefits of renewable energy?" 
+ print(f"User: {query}") + + response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage) + + for msg in response.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show checkpoints persisted in Cosmos DB + checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}") + for i, cp in enumerate(checkpoints[:5], 1): + print(f" {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})") + + # --- Second run: continue conversation with checkpoint history --- + print("\n=== Second Run (continuing conversation) ===\n") + + query2 = "Can you elaborate on the economic benefits?" + print(f"User: {query2}") + + response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) + + for msg in response2.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show total checkpoints + all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}") + + # Get latest checkpoint + latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) + if latest: + print(f"Latest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py new file mode 100644 index 0000000000..4726742ffc --- /dev/null +++ b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +"""Sample: Workflow Checkpointing with Cosmos DB NoSQL. 
+ +Purpose: +This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint +storage backend for workflows, enabling durable pause-and-resume across +process restarts. + +What you learn: +- How to configure CosmosCheckpointStorage for workflow checkpointing +- How to run a workflow that automatically persists checkpoints to Cosmos DB +- How to resume a workflow from a Cosmos DB checkpoint +- How to list and inspect available checkpoints + +Prerequisites: +- An Azure Cosmos DB account (or local emulator) +- Environment variables set (see below) + +Environment variables: + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name + AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +import sys +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowCheckpoint, + WorkflowContext, + handler, +) + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +from agent_framework_azure_cosmos import CosmosCheckpointStorage + + +@dataclass +class ComputeTask: + """Task containing the list of numbers remaining to be processed.""" + + remaining_numbers: list[int] + + +class StartExecutor(Executor): + """Initiates the workflow by providing the upper limit.""" + + @handler + async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None: + """Start the workflow with numbers up to the given limit.""" + print(f"StartExecutor: Starting computation up to {upper_limit}") + await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1)))) + + +class WorkerExecutor(Executor): + """Processes numbers and manages executor state for checkpointing.""" + + def __init__(self, 
id: str) -> None: + """Initialize the worker executor.""" + super().__init__(id=id) + self._results: dict[int, list[tuple[int, int]]] = {} + + @handler + async def compute( + self, + task: ComputeTask, + ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]], + ) -> None: + """Process the next number, computing its factor pairs.""" + next_number = task.remaining_numbers.pop(0) + print(f"WorkerExecutor: Processing {next_number}") + + pairs: list[tuple[int, int]] = [] + for i in range(1, next_number): + if next_number % i == 0: + pairs.append((i, next_number // i)) + self._results[next_number] = pairs + + if not task.remaining_numbers: + await ctx.yield_output(self._results) + else: + await ctx.send_message(task) + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + return {"results": self._results} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + self._results = state.get("results", {}) + + +async def main() -> None: + """Run the workflow checkpointing sample with Cosmos DB.""" + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: + print( + "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " + "and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # Authentication: supports both managed identity/RBAC and key-based auth. + # When AZURE_COSMOS_KEY is set, key-based auth is used. + # Otherwise, falls back to DefaultAzureCredential (properly closed via async with). 
+ if cosmos_key: + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_key, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + await _run_workflow(checkpoint_storage) + else: + from azure.identity.aio import DefaultAzureCredential + + async with DefaultAzureCredential() as credential, CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + await _run_workflow(checkpoint_storage) + + +async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None: + """Build and run the workflow with Cosmos DB checkpointing.""" + start = StartExecutor(id="start") + worker = WorkerExecutor(id="worker") + workflow_builder = ( + WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage) + .add_edge(start, worker) + .add_edge(worker, worker) + ) + + # --- First run: execute the workflow --- + print("\n=== First Run ===\n") + workflow = workflow_builder.build() + + output = None + async for event in workflow.run(message=8, stream=True): + if event.type == "output": + output = event.data + + print(f"Factor pairs computed: {output}") + + # List checkpoints saved in Cosmos DB + checkpoint_ids = await checkpoint_storage.list_checkpoint_ids( + workflow_name=workflow.name, + ) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}") + for cid in checkpoint_ids: + print(f" - {cid}") + + # Get the latest checkpoint + latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest( + workflow_name=workflow.name, + ) + + if latest is None: + print("No checkpoint found to resume from.") + return + + print(f"\nLatest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") + + # --- Second run: resume from the latest checkpoint --- + print("\n=== Resuming from 
Checkpoint ===\n") + workflow2 = workflow_builder.build() + + output2 = None + async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True): + if event.type == "output": + output2 = event.data + + if output2: + print(f"Resumed workflow produced: {output2}") + else: + print("Resumed workflow completed (no remaining work — already finished).") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py b/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py new file mode 100644 index 0000000000..1e8a20c388 --- /dev/null +++ b/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +"""Sample: Combined History Provider + Checkpoint Storage with Azure AI Foundry. + +Purpose: +This sample demonstrates using both CosmosHistoryProvider (conversation memory) +and CosmosCheckpointStorage (workflow pause/resume) together in a single Azure AI +Foundry agent application. This is the recommended production pattern for customers +who need both durable conversations and durable workflow execution. 
+ +What you learn: +- How to wire both CosmosHistoryProvider and CosmosCheckpointStorage in one app +- How conversation history and workflow checkpoints serve complementary roles +- How to resume both conversation context and workflow execution state + +Key concepts: +- CosmosHistoryProvider: Persists conversation messages across sessions +- CosmosCheckpointStorage: Persists workflow execution state for pause/resume +- Together they enable fully durable agent workflows + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint + AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +from typing import Any + +from agent_framework import WorkflowAgent, WorkflowCheckpoint +from agent_framework.azure import AzureOpenAIResponsesClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosCheckpointStorage, CosmosHistoryProvider + +load_dotenv() + + +async def main() -> None: + """Run the combined history + checkpoint sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not project_endpoint or not deployment_name: + print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") + return + + if not cosmos_endpoint or not cosmos_database_name: + print("Please set AZURE_COSMOS_ENDPOINT and AZURE_COSMOS_DATABASE_NAME.") + return + + async with AzureCliCredential() as azure_credential: + cosmos_credential: Any = cosmos_key if 
cosmos_key else azure_credential + + # CosmosHistoryProvider: stores conversation messages + # CosmosCheckpointStorage: stores workflow execution state + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name="conversation-history", + credential=cosmos_credential, + ) as history_provider, + CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_credential, + database_name=cosmos_database_name, + container_name="workflow-checkpoints", + ) as checkpoint_storage, + ): + # Create Azure AI Foundry agents + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=azure_credential, + ) + + assistant = client.as_agent( + name="assistant", + instructions="You are a helpful assistant. Keep responses brief.", + ) + + reviewer = client.as_agent( + name="reviewer", + instructions=( + "You are a reviewer. Provide a one-sentence " + "summary of the assistant's response." + ), + ) + + # Build a workflow with both history and checkpointing. + # Attach the history provider to the WorkflowAgent (outer agent) + # so conversation messages are persisted at the agent level. 
+ workflow = SequentialBuilder( + participants=[assistant, reviewer], + ).build() + agent = WorkflowAgent( + workflow, + name="DurableAgent", + context_providers=[history_provider], + ) + + # --- First run --- + print("=== First Run ===\n") + session = agent.create_session() + + response = await agent.run( + "What are three benefits of cloud computing?", + session=session, + checkpoint_storage=checkpoint_storage, + ) + + for msg in response.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show what's persisted + checkpoints = await checkpoint_storage.list_checkpoints( + workflow_name=workflow.name, + ) + history = await history_provider.get_messages(session.session_id) + + print(f"\nConversation messages in Cosmos DB: {len(history)}") + print(f"Workflow checkpoints in Cosmos DB: {len(checkpoints)}") + + # --- Second run: conversation context is loaded from history --- + print("\n=== Second Run (with conversation context) ===\n") + + response2 = await agent.run( + "Can you elaborate on the first benefit?", + session=session, + checkpoint_storage=checkpoint_storage, + ) + + for msg in response2.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show updated state + latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest( + workflow_name=workflow.name, + ) + history2 = await history_provider.get_messages(session.session_id) + + print(f"\nConversation messages after 2 runs: {len(history2)}") + if latest: + print(f"Latest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/history_provider/__init__.py b/python/packages/azure-cosmos/samples/history_provider/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/azure-cosmos/samples/cosmos_history_provider.py 
b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py similarity index 94% rename from python/packages/azure-cosmos/samples/cosmos_history_provider.py rename to python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py index ff6138c1e5..755b8c92f5 100644 --- a/python/packages/azure-cosmos/samples/cosmos_history_provider.py +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py @@ -23,7 +23,7 @@ Environment variables: AZURE_AI_PROJECT_ENDPOINT - AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME + AZURE_AI_MODEL_DEPLOYMENT_NAME AZURE_COSMOS_ENDPOINT AZURE_COSMOS_DATABASE_NAME AZURE_COSMOS_CONTAINER_NAME @@ -35,7 +35,7 @@ async def main() -> None: """Run the Cosmos history provider sample with an Agent.""" project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") - deployment_name = os.getenv("AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") @@ -49,7 +49,7 @@ async def main() -> None: or not cosmos_container_name ): print( - "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME, " + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." ) return diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py new file mode 100644 index 0000000000..f8e97e8d17 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. 
+# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework import AgentSession +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates persisting and resuming conversations across application +restarts using CosmosHistoryProvider as the persistent backend. + +Key components: +- Phase 1: Run a conversation and serialize the session with session.to_dict() +- Phase 2: Simulate an app restart — create new provider and agent instances, + restore the session with AgentSession.from_dict(), and continue the conversation +- Cosmos DB reloads the full message history, so the agent remembers everything + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the conversation persistence sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # ── Phase 1: Initial conversation ── + + print("=== Phase 1: Initial conversation ===\n") + + # 1. Create credential, client, history provider, and agent. 
+ async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 2. Create a session and have a multi-turn conversation. + session = agent.create_session() + + response1 = await agent.run( + "My name is Ada. I'm building a distributed database in Rust.", session=session + ) + print("User: My name is Ada. I'm building a distributed database in Rust.") + print(f"Assistant: {response1.text}\n") + + response2 = await agent.run("The hardest part is the consensus algorithm.", session=session) + print("User: The hardest part is the consensus algorithm.") + print(f"Assistant: {response2.text}\n") + + # 3. Serialize the session state — this is what you'd persist to a database or file. + serialized_session = session.to_dict() + print(f"Session serialized. Session ID: {session.session_id}") + + # ── Phase 2: Simulate app restart ── + + print("\n=== Phase 2: Resuming after 'restart' ===\n") + + # 4. Create entirely new provider and agent instances (simulating a fresh process). 
+ async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 5. Restore the session from the serialized state. + restored_session = AgentSession.from_dict(serialized_session) + print(f"Session restored. Session ID: {restored_session.session_id}\n") + + # 6. Continue the conversation — history is reloaded from Cosmos DB. + response3 = await agent.run("What was I working on and what was the challenge?", session=restored_session) + print("User: What was I working on and what was the challenge?") + print(f"Assistant: {response3.text}\n") + + # 7. Verify messages are in Cosmos by reading them directly. + messages = await history_provider.get_messages(restored_session.session_id) + print(f"Messages stored in Cosmos DB: {len(messages)}") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:80]}...") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Phase 1: Initial conversation === + +User: My name is Ada. I'm building a distributed database in Rust. +Assistant: That sounds like a great project, Ada! Rust is an excellent choice for ... + +User: The hardest part is the consensus algorithm. +Assistant: Consensus algorithms can be tricky! Are you looking at Raft, Paxos, or ... + +Session serialized. Session ID: + +=== Phase 2: Resuming after 'restart' === + +Session restored. Session ID: + +User: What was I working on and what was the challenge? 
+Assistant: You told me you're building a distributed database in Rust and that the hardest +part is the consensus algorithm. + +Messages stored in Cosmos DB: 6 + 1. [user] My name is Ada. I'm building a distributed database in Rust.... + 2. [assistant] That sounds like a great project, Ada! Rust is an excellent ch... + 3. [user] The hardest part is the consensus algorithm.... + 4. [assistant] Consensus algorithms can be tricky! Are you looking at Raft, Pa... + 5. [user] What was I working on and what was the challenge?... + 6. [assistant] You told me you're building a distributed database in Rust and ... +""" diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py new file mode 100644 index 0000000000..c97504fd40 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates direct message history operations using +CosmosHistoryProvider — retrieving, displaying, and clearing stored messages. 
+ +Key components: +- get_messages(session_id): Retrieve all stored messages as a chat transcript +- clear(session_id): Delete all messages for a session (e.g., GDPR compliance) +- Verifying that history is empty after clearing +- Running a new conversation in the same session after clearing + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the messages history sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="HistoryAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + session = agent.create_session() + session_id = session.session_id + + # 1. Have a multi-turn conversation. 
+ print("=== Building a conversation ===\n") + + queries = [ + "Hi! My favorite programming language is Python.", + "I also enjoy hiking in the mountains on weekends.", + "What do you know about me so far?", + ] + for query in queries: + response = await agent.run(query, session=session) + print(f"User: {query}") + print(f"Assistant: {response.text}\n") + + # 2. Retrieve and display the full message history as a transcript. + print("=== Chat transcript from Cosmos DB ===\n") + + messages = await history_provider.get_messages(session_id) + print(f"Total messages stored: {len(messages)}\n") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:100]}") + + # 3. Clear the session history. + print("\n=== Clearing session history ===\n") + + await history_provider.clear(session_id) + print(f"Cleared all messages for session: {session_id}") + + # 4. Verify history is empty. + remaining = await history_provider.get_messages(session_id) + print(f"Messages after clear: {len(remaining)}") + + # 5. Start a fresh conversation in the same session — agent has no memory. + print("\n=== Fresh conversation (same session, no memory) ===\n") + + response = await agent.run("What do you know about me?", session=session) + print("User: What do you know about me?") + print(f"Assistant: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Building a conversation === + +User: Hi! My favorite programming language is Python. +Assistant: That's great! Python is a wonderful language. What do you like most about it? + +User: I also enjoy hiking in the mountains on weekends. +Assistant: Hiking sounds lovely! Do you have a favorite trail or mountain range? + +User: What do you know about me so far? +Assistant: You love Python as your favorite programming language and enjoy hiking in the mountains on weekends. + +=== Chat transcript from Cosmos DB === + +Total messages stored: 6 + + 1. [user] Hi! 
My favorite programming language is Python. + 2. [assistant] That's great! Python is a wonderful language. What do you like most about it? + 3. [user] I also enjoy hiking in the mountains on weekends. + 4. [assistant] Hiking sounds lovely! Do you have a favorite trail or mountain range? + 5. [user] What do you know about me so far? + 6. [assistant] You love Python as your favorite programming language and enjoy hiking ... + +=== Clearing session history === + +Cleared all messages for session: +Messages after clear: 0 + +=== Fresh conversation (same session, no memory) === + +User: What do you know about me? +Assistant: I don't have any information about you yet. Feel free to share anything you'd like! +""" diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py new file mode 100644 index 0000000000..6772d41825 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py @@ -0,0 +1,198 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates multi-session and multi-tenant management using +CosmosHistoryProvider. Each tenant (user) gets isolated conversation sessions +stored in the same Cosmos DB container, partitioned by session_id. 
+ +Key components: +- Per-tenant session isolation using prefixed session IDs +- list_sessions(): Enumerate all stored sessions across tenants +- Switching between sessions for different users +- Resuming a specific user's session — verifying data isolation + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the session management sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="MultiTenantAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 1. Tenant "alice" starts a conversation about travel. 
+ print("=== Tenant: Alice — Travel conversation ===\n") + + alice_session = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run( + "Hi! I'm planning a trip to Italy. I love Renaissance art.", session=alice_session + ) + print("Alice: I'm planning a trip to Italy. I love Renaissance art.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("Which museums should I visit in Florence?", session=alice_session) + print("Alice: Which museums should I visit in Florence?") + print(f"Assistant: {response.text}\n") + + # 2. Tenant "bob" starts a separate conversation about cooking. + print("=== Tenant: Bob — Cooking conversation ===\n") + + bob_session = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run( + "Hey! I'm learning to cook Thai food. I just made pad thai.", session=bob_session + ) + print("Bob: I'm learning to cook Thai food. I just made pad thai.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("What Thai dish should I try next?", session=bob_session) + print("Bob: What Thai dish should I try next?") + print(f"Assistant: {response.text}\n") + + # 3. List all sessions stored in Cosmos DB. + print("=== Listing all sessions ===\n") + + sessions = await history_provider.list_sessions() + print(f"Found {len(sessions)} session(s):") + for sid in sessions: + print(f" - {sid}") + + # 4. Resume Alice's session — verify she gets her travel context back. + print("\n=== Resuming Alice's session ===\n") + + alice_resumed = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run("What were we discussing?", session=alice_resumed) + print("Alice: What were we discussing?") + print(f"Assistant: {response.text}\n") + + # 5. Resume Bob's session — verify he gets his cooking context back. 
+ print("=== Resuming Bob's session ===\n") + + bob_resumed = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run("What was the last dish I mentioned?", session=bob_resumed) + print("Bob: What was the last dish I mentioned?") + print(f"Assistant: {response.text}\n") + + # 6. Show per-session message counts. + print("=== Per-session message counts ===\n") + + alice_messages = await history_provider.get_messages("tenant-alice-session-1") + bob_messages = await history_provider.get_messages("tenant-bob-session-1") + print(f"Alice's session: {len(alice_messages)} messages") + print(f"Bob's session: {len(bob_messages)} messages") + + # 7. Clean up: clear both sessions. + print("\n=== Cleaning up ===\n") + + await history_provider.clear("tenant-alice-session-1") + await history_provider.clear("tenant-bob-session-1") + print("Cleared Alice's and Bob's sessions.") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Tenant: Alice — Travel conversation === + +Alice: I'm planning a trip to Italy. I love Renaissance art. +Assistant: Italy is a dream for Renaissance art lovers! Florence, Rome, and Venice ... + +Alice: Which museums should I visit in Florence? +Assistant: In Florence, the Uffizi Gallery is a must — it has Botticelli's Birth of Venus ... + +=== Tenant: Bob — Cooking conversation === + +Bob: I'm learning to cook Thai food. I just made pad thai. +Assistant: Pad thai is a great start! How did it turn out? + +Bob: What Thai dish should I try next? +Assistant: I'd suggest trying green curry or tom yum soup — both are classic Thai dishes ... + +=== Listing all sessions === + +Found 2 session(s): + - tenant-alice-session-1 + - tenant-bob-session-1 + +=== Resuming Alice's session === + +Alice: What were we discussing? +Assistant: We were discussing your trip to Italy and your love for Renaissance art ... + +=== Resuming Bob's session === + +Bob: What was the last dish I mentioned? 
+Assistant: You mentioned pad thai — it was the dish you just made! + +=== Per-session message counts === + +Alice's session: 6 messages +Bob's session: 6 messages + +=== Cleaning up === + +Cleared Alice's and Bob's sessions. +""" diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py new file mode 100644 index 0000000000..4249214a45 --- /dev/null +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -0,0 +1,574 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import os +import uuid +from collections.abc import AsyncIterator +from contextlib import suppress +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework._workflows._checkpoint import WorkflowCheckpoint +from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value +from agent_framework.exceptions import SettingNotFoundError, WorkflowCheckpointException +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +import agent_framework_azure_cosmos._checkpoint_storage as checkpoint_storage_module +from agent_framework_azure_cosmos._checkpoint_storage import CosmosCheckpointStorage + +skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif( + any( + os.getenv(name, "") == "" + for name in ( + "AZURE_COSMOS_ENDPOINT", + "AZURE_COSMOS_KEY", + "AZURE_COSMOS_DATABASE_NAME", + "AZURE_COSMOS_CONTAINER_NAME", + ) + ), + reason=( + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and " + "AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests." 
+    ),
+)
+
+
+def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]:
+    async def _iterator() -> AsyncIterator[Any]:
+        for item in items:
+            yield item
+
+    return _iterator()
+
+
+def _make_checkpoint(
+    workflow_name: str = "test-workflow",
+    checkpoint_id: str | None = None,
+    previous_checkpoint_id: str | None = None,
+    timestamp: str | None = None,
+) -> WorkflowCheckpoint:
+    """Create a minimal WorkflowCheckpoint for testing."""
+    return WorkflowCheckpoint(
+        workflow_name=workflow_name,
+        graph_signature_hash="abc123",
+        checkpoint_id=checkpoint_id or str(uuid.uuid4()),
+        previous_checkpoint_id=previous_checkpoint_id,
+        timestamp=timestamp or "2025-01-01T00:00:00+00:00",
+        state={"counter": 42},
+        iteration_count=1,
+    )
+
+
+def _checkpoint_to_cosmos_document(checkpoint: WorkflowCheckpoint) -> dict[str, Any]:
+    """Simulate a stored Cosmos DB document (NOTE: unlike the real save path, ``id`` here omits the ``{workflow_name}_`` prefix)."""
+    encoded = encode_checkpoint_value(checkpoint.to_dict())
+    doc: dict[str, Any] = {
+        "id": checkpoint.checkpoint_id,
+        "workflow_name": checkpoint.workflow_name,
+        **encoded,
+        # Cosmos system properties
+        "_rid": "abc",
+        "_self": "dbs/abc/colls/def/docs/ghi",
+        "_etag": '"00000000-0000-0000-0000-000000000000"',
+        "_attachments": "attachments/",
+        "_ts": 1700000000,
+    }
+    return doc
+
+
+@pytest.fixture
+def mock_container() -> MagicMock:
+    container = MagicMock()
+    container.query_items = MagicMock(return_value=_to_async_iter([]))
+    container.upsert_item = AsyncMock(return_value={})
+    container.delete_item = AsyncMock(return_value={})
+    return container
+
+
+@pytest.fixture
+def mock_cosmos_client(mock_container: MagicMock) -> MagicMock:
+    database_client = MagicMock()
+    database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container)
+
+    client = MagicMock()
+    client.create_database_if_not_exists = AsyncMock(return_value=database_client)
+    client.close = AsyncMock()
+    return client
+
+
+# --- Tests for initialization ---
+
+
+async def 
test_init_uses_provided_container_client(mock_container: MagicMock) -> None:
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    assert storage.database_name == ""
+    assert storage.container_name == ""
+
+
+async def test_init_uses_provided_cosmos_client(mock_cosmos_client: MagicMock) -> None:
+    storage = CosmosCheckpointStorage(
+        cosmos_client=mock_cosmos_client,
+        database_name="db1",
+        container_name="checkpoints",
+    )
+    assert storage.database_name == "db1"
+    assert storage.container_name == "checkpoints"
+
+
+async def test_init_missing_required_settings_raises(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False)
+    monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False)
+    monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False)
+    monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)
+
+    with pytest.raises(SettingNotFoundError, match="database_name"):
+        CosmosCheckpointStorage()
+
+
+async def test_init_constructs_client_with_credential(
+    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
+) -> None:
+    """Passes the Azure credential through to CosmosClient (RBAC path only: AZURE_COSMOS_KEY is cleared above, so the key branch is never taken)."""
+    mock_factory = MagicMock(return_value=mock_cosmos_client)
+    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
+    monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False)
+
+    # Simulate real-world pattern: use key if available, else RBAC credential
+    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
+    credential: Any = cosmos_key if cosmos_key else MagicMock()  # MagicMock simulates DefaultAzureCredential()
+
+    CosmosCheckpointStorage(
+        endpoint="https://account.documents.azure.com:443/",
+        credential=credential,
+        database_name="db1",
+        container_name="checkpoints",
+    )
+
+    mock_factory.assert_called_once()
+    kwargs = mock_factory.call_args.kwargs
+    assert kwargs["url"] == "https://account.documents.azure.com:443/"
+    assert kwargs["credential"] 
is credential + + +async def test_init_creates_database_and_container(mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="custom-checkpoints", + ) + + await storage.list_checkpoint_ids(workflow_name="wf") + + mock_cosmos_client.create_database_if_not_exists.assert_awaited_once_with(id="db1") + database_client = mock_cosmos_client.create_database_if_not_exists.return_value + assert database_client.create_container_if_not_exists.await_count == 1 + kwargs = database_client.create_container_if_not_exists.await_args.kwargs + assert kwargs["id"] == "custom-checkpoints" + + +# --- Tests for save --- + + +async def test_save_upserts_document(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint() + + result = await storage.save(checkpoint) + + assert result == checkpoint.checkpoint_id + mock_container.upsert_item.assert_awaited_once() + document = mock_container.upsert_item.await_args.kwargs["body"] + assert document["id"] == f"test-workflow_{checkpoint.checkpoint_id}" + assert document["workflow_name"] == "test-workflow" + assert document["graph_signature_hash"] == "abc123" + assert document["state"]["counter"] == 42 + + +async def test_save_returns_checkpoint_id(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint(checkpoint_id="cp-123") + + result = await storage.save(checkpoint) + + assert result == "cp-123" + + +# --- Tests for load --- + + +async def test_load_returns_checkpoint(mock_container: MagicMock) -> None: + checkpoint = _make_checkpoint(checkpoint_id="cp-load") + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + loaded = await storage.load("cp-load") + + 
assert loaded.checkpoint_id == "cp-load" + assert loaded.workflow_name == "test-workflow" + assert loaded.graph_signature_hash == "abc123" + assert loaded.state["counter"] == 42 + + +async def test_load_nonexistent_raises(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"): + await storage.load("nonexistent-id") + + +async def test_load_queries_without_partition_key(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + with suppress(WorkflowCheckpointException): + await storage.load("cp-id") + + kwargs = mock_container.query_items.call_args.kwargs + assert "partition_key" not in kwargs + + +async def test_load_multiple_workflows_same_checkpoint_id_raises(mock_container: MagicMock) -> None: + cp1 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-a") + cp2 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-b") + mock_container.query_items.return_value = _to_async_iter([ + _checkpoint_to_cosmos_document(cp1), + _checkpoint_to_cosmos_document(cp2), + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="Multiple checkpoints found"): + await storage.load("shared-id") + + +# --- Tests for list_checkpoints --- + + +async def test_list_checkpoints_returns_checkpoints_for_workflow(mock_container: MagicMock) -> None: + cp1 = _make_checkpoint(checkpoint_id="cp-1", timestamp="2025-01-01T00:00:00+00:00") + cp2 = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00") + mock_container.query_items.return_value = _to_async_iter([ + _checkpoint_to_cosmos_document(cp1), + _checkpoint_to_cosmos_document(cp2), + ]) + + storage = 
CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert len(results) == 2 + assert results[0].checkpoint_id == "cp-1" + assert results[1].checkpoint_id == "cp-2" + + +async def test_list_checkpoints_uses_partition_key(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.list_checkpoints(workflow_name="my-workflow") + + kwargs = mock_container.query_items.call_args.kwargs + assert kwargs["partition_key"] == "my-workflow" + + +async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert results == [] + + +# --- Tests for delete --- + + +async def test_delete_existing_returns_true(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"}, + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("cp-del") + + assert result is True + mock_container.delete_item.assert_awaited_once_with( + item="test-workflow_cp-del", + partition_key="test-workflow", + ) + + +async def test_delete_nonexistent_returns_false(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("nonexistent") + + assert result is False + mock_container.delete_item.assert_not_awaited() + + +async def test_delete_cosmos_not_found_returns_false(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"id": 
"test-workflow_cp-del", "workflow_name": "test-workflow"}, + ]) + mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("cp-del") + + assert result is False + + +# --- Tests for get_latest --- + + +async def test_get_latest_returns_latest_checkpoint(mock_container: MagicMock) -> None: + cp = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00") + mock_container.query_items.return_value = _to_async_iter([ + _checkpoint_to_cosmos_document(cp), + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is not None + assert result.checkpoint_id == "cp-latest" + + +async def test_get_latest_returns_none_when_empty(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is None + + +async def test_get_latest_uses_order_by_desc_with_limit(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.get_latest(workflow_name="test-workflow") + + kwargs = mock_container.query_items.call_args.kwargs + assert "ORDER BY c.timestamp DESC" in kwargs["query"] + assert "OFFSET 0 LIMIT 1" in kwargs["query"] + + +# --- Tests for list_checkpoint_ids --- + + +async def test_list_checkpoint_ids_returns_ids(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"checkpoint_id": "cp-1"}, + {"checkpoint_id": "cp-2"}, + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + ids = await storage.list_checkpoint_ids(workflow_name="test-workflow") + + 
assert ids == ["cp-1", "cp-2"] + + +async def test_list_checkpoint_ids_empty_returns_empty(mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + ids = await storage.list_checkpoint_ids(workflow_name="test-workflow") + + assert ids == [] + + +# --- Tests for close and context manager --- + + +async def test_close_closes_owned_client( + monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock +) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) + + storage = CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="checkpoints", + ) + + await storage.close() + + mock_cosmos_client.close.assert_awaited_once() + + +async def test_close_does_not_close_external_client(mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="checkpoints", + ) + + await storage.close() + + mock_cosmos_client.close.assert_not_awaited() + + +async def test_context_manager_closes_owned_client( + monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock +) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) + + async with CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="checkpoints", + ) as storage: + assert storage is not None + + mock_cosmos_client.close.assert_awaited_once() + + +async def test_context_manager_preserves_original_exception(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + + with ( + patch.object(storage, "close", 
AsyncMock(side_effect=RuntimeError("close failed"))), + pytest.raises(ValueError, match="inner error"), + ): + async with storage: + raise ValueError("inner error") + + +# --- Tests for save/load round-trip --- + + +async def test_round_trip_preserves_data(mock_container: MagicMock) -> None: + checkpoint = _make_checkpoint( + checkpoint_id="cp-roundtrip", + previous_checkpoint_id="cp-parent", + ) + checkpoint.state = {"key": "value", "nested": {"a": 1}} + checkpoint.metadata = {"superstep": 3} + checkpoint.iteration_count = 5 + + saved_doc: dict[str, Any] = {} + + async def capture_upsert(body: dict[str, Any]) -> dict[str, Any]: + saved_doc.update(body) + return body + + mock_container.upsert_item = AsyncMock(side_effect=capture_upsert) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.save(checkpoint) + + returned_doc = { + **saved_doc, + "_rid": "abc", + "_self": "dbs/abc/colls/def/docs/ghi", + "_etag": '"etag"', + "_attachments": "attachments/", + "_ts": 1700000000, + } + mock_container.query_items.return_value = _to_async_iter([returned_doc]) + + loaded = await storage.load("cp-roundtrip") + + assert loaded.checkpoint_id == checkpoint.checkpoint_id + assert loaded.workflow_name == checkpoint.workflow_name + assert loaded.graph_signature_hash == checkpoint.graph_signature_hash + assert loaded.previous_checkpoint_id == "cp-parent" + assert loaded.state == {"key": "value", "nested": {"a": 1}} + assert loaded.metadata == {"superstep": 3} + assert loaded.iteration_count == 5 + assert loaded.version == "1.0" + + +# --- Integration test --- + + +@pytest.mark.integration +@skip_if_cosmos_integration_tests_disabled +async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None: + endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "") + key = os.getenv("AZURE_COSMOS_KEY", "") + database_prefix = os.getenv("AZURE_COSMOS_DATABASE_NAME", "") + container_prefix = os.getenv("AZURE_COSMOS_CONTAINER_NAME", "") + unique = 
uuid.uuid4().hex[:8] + database_name = f"{database_prefix}-cp-{unique}" + container_name = f"{container_prefix}-cp-{unique}" + + async with CosmosClient(url=endpoint, credential=key) as cosmos_client: + await cosmos_client.create_database_if_not_exists(id=database_name) + + storage = CosmosCheckpointStorage( + cosmos_client=cosmos_client, + database_name=database_name, + container_name=container_name, + ) + + try: + # Save two checkpoints for the same workflow + cp1 = _make_checkpoint( + checkpoint_id="cp-int-1", + workflow_name="integration-wf", + timestamp="2025-01-01T00:00:00+00:00", + ) + cp2 = _make_checkpoint( + checkpoint_id="cp-int-2", + workflow_name="integration-wf", + previous_checkpoint_id="cp-int-1", + timestamp="2025-01-02T00:00:00+00:00", + ) + cp2.state = {"step": 2} + + await storage.save(cp1) + await storage.save(cp2) + + # Load by ID + loaded = await storage.load("cp-int-1") + assert loaded.checkpoint_id == "cp-int-1" + assert loaded.workflow_name == "integration-wf" + + # List all checkpoints for workflow + all_cps = await storage.list_checkpoints(workflow_name="integration-wf") + assert len(all_cps) == 2 + + # List checkpoint IDs + ids = await storage.list_checkpoint_ids(workflow_name="integration-wf") + assert "cp-int-1" in ids + assert "cp-int-2" in ids + + # Get latest + latest = await storage.get_latest(workflow_name="integration-wf") + assert latest is not None + assert latest.checkpoint_id == "cp-int-2" + assert latest.state == {"step": 2} + + # Delete + assert await storage.delete("cp-int-1") is True + assert await storage.delete("cp-int-1") is False + + remaining = await storage.list_checkpoint_ids(workflow_name="integration-wf") + assert remaining == ["cp-int-2"] + + # Cross-workflow isolation + other_cp = _make_checkpoint( + checkpoint_id="cp-other", + workflow_name="other-wf", + ) + await storage.save(other_cp) + wf_cps = await storage.list_checkpoints(workflow_name="integration-wf") + assert len(wf_cps) == 1 + assert 
wf_cps[0].checkpoint_id == "cp-int-2" + + finally: + with suppress(Exception): + await cosmos_client.delete_database(database_name)