From 8c61234f308448d3e3fd1290aab5ef44a2b15371 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Wed, 25 Mar 2026 13:11:05 -0700 Subject: [PATCH 1/8] Add CosmosCheckpointStorage for Python workflow checkpointing Add native Cosmos DB NoSQL support for workflow checkpoint storage in the Python agent-framework-azure-cosmos package, achieving parity with the existing .NET CosmosCheckpointStore. New files: - _checkpoint_storage.py: CosmosCheckpointStorage implementing the CheckpointStorage protocol with 6 methods (save, load, list_checkpoints, delete, get_latest, list_checkpoint_ids) - test_cosmos_checkpoint_storage.py: Unit and integration tests - workflow_checkpointing.py: Sample demonstrating Cosmos DB-backed workflow checkpoint/resume Auth support: - Managed identity / RBAC via Azure credential objects (DefaultAzureCredential, ManagedIdentityCredential, etc.) - Key-based auth via account key string or AZURE_COSMOS_KEY env var - Pre-created CosmosClient or ContainerProxy Key design decisions: - Partition key: /workflow_name for efficient per-workflow queries - Serialization: Reuses encode/decode_checkpoint_value for full Python object fidelity (hybrid JSON + pickle approach) - Container auto-creation via create_container_if_not_exists Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/packages/azure-cosmos/README.md | 90 ++- .../agent_framework_azure_cosmos/__init__.py | 2 + .../_checkpoint_storage.py | 402 +++++++++++++ .../tests/test_cosmos_checkpoint_storage.py | 531 ++++++++++++++++++ .../conversations/workflow_checkpointing.py | 205 +++++++ 5 files changed, 1223 insertions(+), 7 deletions(-) create mode 100644 python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py create mode 100644 python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py create mode 100644 python/samples/02-agents/conversations/workflow_checkpointing.py diff --git a/python/packages/azure-cosmos/README.md 
b/python/packages/azure-cosmos/README.md index d2868c78b7..8a46b0ca7a 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -14,7 +14,7 @@ The Azure Cosmos DB integration provides `CosmosHistoryProvider` for persistent ```python from azure.identity.aio import DefaultAzureCredential -from agent_framework.azure import CosmosHistoryProvider +from agent_framework_azure_cosmos import CosmosHistoryProvider provider = CosmosHistoryProvider( endpoint="https://.documents.azure.com:443/", @@ -35,13 +35,89 @@ Container naming behavior: - Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`) - `session_id` is used as the Cosmos partition key for reads/writes -See the [conversation samples](../../samples/02-agents/conversations/) for runnable examples, including -[`cosmos_history_provider.py`](../../samples/02-agents/conversations/cosmos_history_provider.py). +See `samples/cosmos_history_provider.py` for a runnable package-local example. + +## Cosmos DB Workflow Checkpoint Storage + +`CosmosCheckpointStorage` implements the `CheckpointStorage` protocol, enabling +durable workflow checkpointing backed by Azure Cosmos DB NoSQL. Workflows can be +paused and resumed across process restarts by persisting checkpoint state in Cosmos DB. 
-## Import Paths +### Basic Usage + +#### Managed Identity / RBAC (recommended for production) ```python -from agent_framework.azure import CosmosHistoryProvider -# or directly: -from agent_framework_azure_cosmos import CosmosHistoryProvider +from azure.identity.aio import DefaultAzureCredential +from agent_framework import WorkflowBuilder +from agent_framework_azure_cosmos import CosmosCheckpointStorage + +checkpoint_storage = CosmosCheckpointStorage( + endpoint="https://.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-framework", + container_name="workflow-checkpoints", +) +``` + +#### Account Key + +```python +from agent_framework_azure_cosmos import CosmosCheckpointStorage + +checkpoint_storage = CosmosCheckpointStorage( + endpoint="https://.documents.azure.com:443/", + credential="", + database_name="agent-framework", + container_name="workflow-checkpoints", +) +``` + +#### Then use with a workflow + +```python +from agent_framework import WorkflowBuilder + +# Build a workflow with checkpointing enabled +workflow = WorkflowBuilder( + start_executor=start, + checkpoint_storage=checkpoint_storage, +).build() + +# Run the workflow — checkpoints are automatically saved after each superstep +result = await workflow.run(message="input data") + +# Resume from a checkpoint +latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) +if latest: + resumed = await workflow.run(checkpoint_id=latest.checkpoint_id) ``` + +### Authentication Options + +`CosmosCheckpointStorage` supports the same authentication modes as `CosmosHistoryProvider`: + +- **Managed identity / RBAC** (recommended): Pass `DefaultAzureCredential()`, + `ManagedIdentityCredential()`, or any Azure `TokenCredential` +- **Account key**: Pass a key string via `credential` parameter +- **Environment variables**: Set `AZURE_COSMOS_ENDPOINT`, `AZURE_COSMOS_DATABASE_NAME`, + `AZURE_COSMOS_CONTAINER_NAME`, and `AZURE_COSMOS_KEY` (key not required when using 
+ Azure credentials) +- **Pre-created client**: Pass an existing `CosmosClient` or `ContainerProxy` + +### Container Setup + +The container is created automatically on first use with `/workflow_name` as the +partition key. You can also pre-create the container in the Azure portal with this +partition key configuration. + +### Environment Variables + +| Variable | Description | +|---|---| +| `AZURE_COSMOS_ENDPOINT` | Cosmos DB account endpoint | +| `AZURE_COSMOS_DATABASE_NAME` | Database name | +| `AZURE_COSMOS_CONTAINER_NAME` | Container name | +| `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) | + +See `samples/workflow_checkpointing.py` for a complete runnable example. diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py index 5bcfb3928b..66373b0f1d 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/__init__.py @@ -2,6 +2,7 @@ import importlib.metadata +from ._checkpoint_storage import CosmosCheckpointStorage from ._history_provider import CosmosHistoryProvider try: @@ -10,6 +11,7 @@ __version__ = "0.0.0" # Fallback for development mode __all__ = [ + "CosmosCheckpointStorage", "CosmosHistoryProvider", "__version__", ] diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py new file mode 100644 index 0000000000..5c567a1645 --- /dev/null +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -0,0 +1,402 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Azure Cosmos DB checkpoint storage for workflow checkpointing.""" + +from __future__ import annotations + +import logging +from typing import Any, TypedDict + +from agent_framework import AGENT_FRAMEWORK_USER_AGENT +from agent_framework._settings import SecretString, load_settings +from agent_framework._workflows._checkpoint import CheckpointID, WorkflowCheckpoint +from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from agent_framework.azure._entra_id_authentication import AzureCredentialTypes +from agent_framework.exceptions import WorkflowCheckpointException +from azure.cosmos import PartitionKey +from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +logger = logging.getLogger(__name__) + + +class AzureCosmosCheckpointSettings(TypedDict, total=False): + """Settings for CosmosCheckpointStorage resolved from args and environment.""" + + endpoint: str | None + database_name: str | None + container_name: str | None + key: SecretString | None + + +class CosmosCheckpointStorage: + """Azure Cosmos DB-backed checkpoint storage for workflow checkpointing. + + Implements the ``CheckpointStorage`` protocol using Azure Cosmos DB NoSQL + as the persistent backend. Checkpoints are stored as JSON documents with + ``workflow_name`` as the partition key, enabling efficient per-workflow queries. + + This storage uses the same hybrid JSON + pickle encoding as + ``FileCheckpointStorage``, allowing full Python object fidelity for + complex workflow state while keeping the document structure human-readable. + + SECURITY WARNING: Checkpoints use pickle for data serialization. Only load + checkpoints from trusted sources. Loading a malicious checkpoint can execute + arbitrary code. + + The container is created automatically on first use with partition key + ``/workflow_name`` if it does not already exist. 
+ + Example using managed identity / RBAC:: + + from azure.identity.aio import DefaultAzureCredential + from agent_framework_azure_cosmos import CosmosCheckpointStorage + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-db", + container_name="checkpoints", + ) + + Example using account key:: + + storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential="my-account-key", + database_name="agent-db", + container_name="checkpoints", + ) + + Then use with a workflow builder:: + + workflow = WorkflowBuilder( + start_executor=start, + checkpoint_storage=storage, + ).build() + """ + + def __init__( + self, + *, + endpoint: str | None = None, + database_name: str | None = None, + container_name: str | None = None, + credential: str | AzureCredentialTypes | None = None, + cosmos_client: CosmosClient | None = None, + container_client: ContainerProxy | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initialize the Azure Cosmos DB checkpoint storage. + + Supports multiple authentication modes: + + - **Container client** (``container_client``): Use a pre-created + Cosmos async container proxy. No client lifecycle is managed. + - **Cosmos client** (``cosmos_client``): Use a pre-created Cosmos + async client. The caller is responsible for closing it. + - **Endpoint + credential**: Create a new Cosmos client. The storage + owns the client and closes it on ``close()``. + - **Environment variables**: Falls back to ``AZURE_COSMOS_ENDPOINT``, + ``AZURE_COSMOS_DATABASE_NAME``, ``AZURE_COSMOS_CONTAINER_NAME``, + and ``AZURE_COSMOS_KEY``. + + Args: + endpoint: Cosmos DB account endpoint. + Can be set via ``AZURE_COSMOS_ENDPOINT``. + database_name: Cosmos DB database name. + Can be set via ``AZURE_COSMOS_DATABASE_NAME``. + container_name: Cosmos DB container name. 
+ Can be set via ``AZURE_COSMOS_CONTAINER_NAME``. + credential: Credential to authenticate with Cosmos DB. + For **managed identity / RBAC**, pass an Azure credential object + such as ``DefaultAzureCredential()`` or + ``ManagedIdentityCredential()``. + For **key-based auth**, pass the account key as a string, + or set ``AZURE_COSMOS_KEY`` in the environment. + cosmos_client: Pre-created Cosmos async client. + container_client: Pre-created Cosmos container client. + env_file_path: Path to environment file for loading settings. + env_file_encoding: Encoding of the environment file. + """ + self._cosmos_client: CosmosClient | None = cosmos_client + self._container_proxy: ContainerProxy | None = container_client + self._owns_client = False + self._database_client: DatabaseProxy | None = None + + if self._container_proxy is not None: + self.database_name: str = database_name or "" + self.container_name: str = container_name or "" + return + + required_fields: list[str] = ["database_name", "container_name"] + if cosmos_client is None: + required_fields.append("endpoint") + if credential is None: + required_fields.append("key") + + settings = load_settings( + AzureCosmosCheckpointSettings, + env_prefix="AZURE_COSMOS_", + required_fields=required_fields, + endpoint=endpoint, + database_name=database_name, + container_name=container_name, + key=credential if isinstance(credential, str) else None, + env_file_path=env_file_path, + env_file_encoding=env_file_encoding, + ) + self.database_name = settings["database_name"] # type: ignore[assignment] + self.container_name = settings["container_name"] # type: ignore[assignment] + + if self._cosmos_client is None: + self._cosmos_client = CosmosClient( + url=settings["endpoint"], # type: ignore[arg-type] + credential=credential or settings["key"].get_secret_value(), # type: ignore[arg-type,union-attr] + user_agent_suffix=AGENT_FRAMEWORK_USER_AGENT, + ) + self._owns_client = True + + self._database_client = 
self._cosmos_client.get_database_client(self.database_name) + + async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID: + """Save a checkpoint to Cosmos DB and return its ID. + + The checkpoint is encoded to a JSON-compatible form (using pickle for + non-JSON-native values) and stored as a Cosmos DB document with the + ``workflow_name`` as the partition key. + + Args: + checkpoint: The WorkflowCheckpoint object to save. + + Returns: + The unique ID of the saved checkpoint. + """ + await self._ensure_container_proxy() + + checkpoint_dict = checkpoint.to_dict() + encoded = encode_checkpoint_value(checkpoint_dict) + + document: dict[str, Any] = { + "id": checkpoint.checkpoint_id, + "workflow_name": checkpoint.workflow_name, + **encoded, + } + + await self._container_proxy.upsert_item(body=document) # type: ignore[union-attr] + logger.info("Saved checkpoint %s to Cosmos DB", checkpoint.checkpoint_id) + return checkpoint.checkpoint_id + + async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint: + """Load a checkpoint from Cosmos DB by ID. + + Args: + checkpoint_id: The unique ID of the checkpoint to load. + + Returns: + The WorkflowCheckpoint object corresponding to the given ID. + + Raises: + WorkflowCheckpointException: If no checkpoint with the given ID exists. + """ + await self._ensure_container_proxy() + + query = "SELECT * FROM c WHERE c.id = @checkpoint_id" + parameters: list[dict[str, object]] = [ + {"name": "@checkpoint_id", "value": checkpoint_id}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + enable_cross_partition_query=True, + ) + + async for item in items: + return self._document_to_checkpoint(item) + + raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}") + + async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]: + """List checkpoint objects for a given workflow name. 
+ + Args: + workflow_name: The name of the workflow to list checkpoints for. + + Returns: + A list of WorkflowCheckpoint objects for the specified workflow name. + """ + await self._ensure_container_proxy() + + query = "SELECT * FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp ASC" + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + checkpoints: list[WorkflowCheckpoint] = [] + async for item in items: + try: + checkpoints.append(self._document_to_checkpoint(item)) + except Exception as e: + logger.warning("Failed to decode checkpoint document: %s", e) + return checkpoints + + async def delete(self, checkpoint_id: CheckpointID) -> bool: + """Delete a checkpoint from Cosmos DB by ID. + + Args: + checkpoint_id: The unique ID of the checkpoint to delete. + + Returns: + True if the checkpoint was successfully deleted, False if not found. + """ + await self._ensure_container_proxy() + + # We need to find the document first to get its partition key + query = "SELECT c.id, c.workflow_name FROM c WHERE c.id = @checkpoint_id" + parameters: list[dict[str, object]] = [ + {"name": "@checkpoint_id", "value": checkpoint_id}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + enable_cross_partition_query=True, + ) + + async for item in items: + try: + await self._container_proxy.delete_item( # type: ignore[union-attr] + item=checkpoint_id, + partition_key=item["workflow_name"], + ) + logger.info("Deleted checkpoint %s from Cosmos DB", checkpoint_id) + return True + except CosmosResourceNotFoundError: + return False + + return False + + async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: + """Get the latest checkpoint for a given workflow name. 
+ + Args: + workflow_name: The name of the workflow to get the latest checkpoint for. + + Returns: + The latest WorkflowCheckpoint, or None if no checkpoints exist. + """ + await self._ensure_container_proxy() + + query = ( + "SELECT * FROM c WHERE c.workflow_name = @workflow_name " + "ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1" + ) + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + async for item in items: + checkpoint = self._document_to_checkpoint(item) + logger.debug( + "Latest checkpoint for workflow %s is %s", + workflow_name, + checkpoint.checkpoint_id, + ) + return checkpoint + + return None + + async def list_checkpoint_ids(self, *, workflow_name: str) -> list[CheckpointID]: + """List checkpoint IDs for a given workflow name. + + Args: + workflow_name: The name of the workflow to list checkpoint IDs for. + + Returns: + A list of checkpoint IDs for the specified workflow name. 
+ """ + await self._ensure_container_proxy() + + query = ( + "SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name " + "ORDER BY c.timestamp ASC" + ) + parameters: list[dict[str, object]] = [ + {"name": "@workflow_name", "value": workflow_name}, + ] + + items = self._container_proxy.query_items( # type: ignore[union-attr] + query=query, + parameters=parameters, + partition_key=workflow_name, + ) + + checkpoint_ids: list[CheckpointID] = [] + async for item in items: + cid = item.get("checkpoint_id") + if isinstance(cid, str): + checkpoint_ids.append(cid) + return checkpoint_ids + + async def close(self) -> None: + """Close the underlying Cosmos client when this storage owns it.""" + if self._owns_client and self._cosmos_client is not None: + await self._cosmos_client.close() + + async def __aenter__(self) -> CosmosCheckpointStorage: + """Async context manager entry.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + """Async context manager exit.""" + try: + await self.close() + except Exception: + if exc_type is None: + raise + + async def _ensure_container_proxy(self) -> None: + """Get or create the Cosmos DB container for storing checkpoints.""" + if self._container_proxy is not None: + return + if self._database_client is None: + raise RuntimeError("Cosmos database client is not initialized.") + + self._container_proxy = await self._database_client.create_container_if_not_exists( + id=self.container_name, + partition_key=PartitionKey(path="/workflow_name"), + ) + + @staticmethod + def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint: + """Convert a Cosmos DB document back to a WorkflowCheckpoint. + + Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``, + ``_attachments``, ``_ts``) before decoding. 
+ """ + # Remove Cosmos DB system properties and the 'id' field + # (checkpoints use 'checkpoint_id', not 'id') + cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"} + cleaned = {k: v for k, v in document.items() if k not in cosmos_keys} + + decoded = decode_checkpoint_value(cleaned) + return WorkflowCheckpoint.from_dict(decoded) diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py new file mode 100644 index 0000000000..34bbf46b0a --- /dev/null +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -0,0 +1,531 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import os +import uuid +from collections.abc import AsyncIterator +from contextlib import suppress +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from agent_framework._workflows._checkpoint import WorkflowCheckpoint +from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value +from agent_framework.exceptions import SettingNotFoundError, WorkflowCheckpointException +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosResourceNotFoundError + +import agent_framework_azure_cosmos._checkpoint_storage as checkpoint_storage_module +from agent_framework_azure_cosmos._checkpoint_storage import CosmosCheckpointStorage + +skip_if_cosmos_integration_tests_disabled = pytest.mark.skipif( + any( + os.getenv(name, "") == "" + for name in ( + "AZURE_COSMOS_ENDPOINT", + "AZURE_COSMOS_KEY", + "AZURE_COSMOS_DATABASE_NAME", + "AZURE_COSMOS_CONTAINER_NAME", + ) + ), + reason=( + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_KEY, AZURE_COSMOS_DATABASE_NAME, and " + "AZURE_COSMOS_CONTAINER_NAME are required for Cosmos integration tests." 
+ ), +) + + +def _to_async_iter(items: list[Any]) -> AsyncIterator[Any]: + async def _iterator() -> AsyncIterator[Any]: + for item in items: + yield item + + return _iterator() + + +def _make_checkpoint( + workflow_name: str = "test-workflow", + checkpoint_id: str | None = None, + previous_checkpoint_id: str | None = None, + timestamp: str | None = None, +) -> WorkflowCheckpoint: + """Create a minimal WorkflowCheckpoint for testing.""" + return WorkflowCheckpoint( + workflow_name=workflow_name, + graph_signature_hash="abc123", + checkpoint_id=checkpoint_id or str(uuid.uuid4()), + previous_checkpoint_id=previous_checkpoint_id, + timestamp=timestamp or "2025-01-01T00:00:00+00:00", + state={"counter": 42}, + iteration_count=1, + ) + + +def _checkpoint_to_cosmos_document(checkpoint: WorkflowCheckpoint) -> dict[str, Any]: + """Simulate what a Cosmos DB document looks like after save.""" + encoded = encode_checkpoint_value(checkpoint.to_dict()) + doc: dict[str, Any] = { + "id": checkpoint.checkpoint_id, + "workflow_name": checkpoint.workflow_name, + **encoded, + # Cosmos system properties + "_rid": "abc", + "_self": "dbs/abc/colls/def/docs/ghi", + "_etag": '"00000000-0000-0000-0000-000000000000"', + "_attachments": "attachments/", + "_ts": 1700000000, + } + return doc + + +@pytest.fixture +def mock_container() -> MagicMock: + container = MagicMock() + container.query_items = MagicMock(return_value=_to_async_iter([])) + container.upsert_item = AsyncMock(return_value={}) + container.delete_item = AsyncMock(return_value={}) + return container + + +@pytest.fixture +def mock_cosmos_client(mock_container: MagicMock) -> MagicMock: + database_client = MagicMock() + database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container) + + client = MagicMock() + client.get_database_client.return_value = database_client + client.close = AsyncMock() + return client + + +class TestCosmosCheckpointStorageInit: + def test_uses_provided_container_client(self, 
mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + assert storage.database_name == "" + assert storage.container_name == "" + + def test_uses_provided_cosmos_client(self, mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="checkpoints", + ) + + mock_cosmos_client.get_database_client.assert_called_once_with("db1") + assert storage.database_name == "db1" + assert storage.container_name == "checkpoints" + + def test_missing_required_settings_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False) + monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) + + with pytest.raises(SettingNotFoundError, match="database_name"): + CosmosCheckpointStorage() + + def test_constructs_client_with_credential( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + """Uses key-based auth when a key string is provided, otherwise falls back to Azure credential (RBAC).""" + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) + monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) + + # Simulate real-world pattern: use key if available, else RBAC credential + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + credential: Any = cosmos_key if cosmos_key else MagicMock() # MagicMock simulates DefaultAzureCredential() + + CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential=credential, + database_name="db1", + container_name="checkpoints", + ) + + mock_factory.assert_called_once() + kwargs = mock_factory.call_args.kwargs + assert kwargs["url"] == "https://account.documents.azure.com:443/" + assert 
kwargs["credential"] is credential + + +class TestCosmosCheckpointStorageContainerConfig: + async def test_container_name_is_used(self, mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="custom-checkpoints", + ) + + await storage.list_checkpoint_ids(workflow_name="wf") + + database_client = mock_cosmos_client.get_database_client.return_value + assert database_client.create_container_if_not_exists.await_count == 1 + kwargs = database_client.create_container_if_not_exists.await_args.kwargs + assert kwargs["id"] == "custom-checkpoints" + + +class TestCosmosCheckpointStorageSave: + async def test_save_upserts_document(self, mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint() + + result = await storage.save(checkpoint) + + assert result == checkpoint.checkpoint_id + mock_container.upsert_item.assert_awaited_once() + document = mock_container.upsert_item.await_args.kwargs["body"] + assert document["id"] == checkpoint.checkpoint_id + assert document["workflow_name"] == "test-workflow" + assert document["graph_signature_hash"] == "abc123" + assert document["state"]["counter"] == 42 + + async def test_save_returns_checkpoint_id(self, mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint(checkpoint_id="cp-123") + + result = await storage.save(checkpoint) + + assert result == "cp-123" + + +class TestCosmosCheckpointStorageLoad: + async def test_load_returns_checkpoint(self, mock_container: MagicMock) -> None: + checkpoint = _make_checkpoint(checkpoint_id="cp-load") + doc = _checkpoint_to_cosmos_document(checkpoint) + mock_container.query_items.return_value = _to_async_iter([doc]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + loaded = await storage.load("cp-load") + + assert 
loaded.checkpoint_id == "cp-load" + assert loaded.workflow_name == "test-workflow" + assert loaded.graph_signature_hash == "abc123" + assert loaded.state["counter"] == 42 + + async def test_load_nonexistent_raises(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + + with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"): + await storage.load("nonexistent-id") + + async def test_load_uses_cross_partition_query(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + with suppress(WorkflowCheckpointException): + await storage.load("cp-id") + + kwargs = mock_container.query_items.call_args.kwargs + assert kwargs["enable_cross_partition_query"] is True + + +class TestCosmosCheckpointStorageListCheckpoints: + async def test_returns_checkpoints_for_workflow(self, mock_container: MagicMock) -> None: + cp1 = _make_checkpoint(checkpoint_id="cp-1", timestamp="2025-01-01T00:00:00+00:00") + cp2 = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00") + mock_container.query_items.return_value = _to_async_iter([ + _checkpoint_to_cosmos_document(cp1), + _checkpoint_to_cosmos_document(cp2), + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert len(results) == 2 + assert results[0].checkpoint_id == "cp-1" + assert results[1].checkpoint_id == "cp-2" + + async def test_uses_partition_key(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.list_checkpoints(workflow_name="my-workflow") + + kwargs = mock_container.query_items.call_args.kwargs + assert 
kwargs["partition_key"] == "my-workflow" + + async def test_empty_returns_empty(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert results == [] + + +class TestCosmosCheckpointStorageDelete: + async def test_delete_existing_returns_true(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"id": "cp-del", "workflow_name": "test-workflow"}, + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("cp-del") + + assert result is True + mock_container.delete_item.assert_awaited_once_with( + item="cp-del", + partition_key="test-workflow", + ) + + async def test_delete_nonexistent_returns_false(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("nonexistent") + + assert result is False + mock_container.delete_item.assert_not_awaited() + + async def test_delete_cosmos_not_found_returns_false(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"id": "cp-del", "workflow_name": "test-workflow"}, + ]) + mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.delete("cp-del") + + assert result is False + + +class TestCosmosCheckpointStorageGetLatest: + async def test_returns_latest_checkpoint(self, mock_container: MagicMock) -> None: + cp = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00") + mock_container.query_items.return_value = _to_async_iter([ + _checkpoint_to_cosmos_document(cp), + ]) + + storage = 
CosmosCheckpointStorage(container_client=mock_container) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is not None + assert result.checkpoint_id == "cp-latest" + + async def test_returns_none_when_empty(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + result = await storage.get_latest(workflow_name="test-workflow") + + assert result is None + + async def test_uses_order_by_desc_with_limit(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.get_latest(workflow_name="test-workflow") + + kwargs = mock_container.query_items.call_args.kwargs + assert "ORDER BY c.timestamp DESC" in kwargs["query"] + assert "OFFSET 0 LIMIT 1" in kwargs["query"] + + +class TestCosmosCheckpointStorageListIds: + async def test_returns_checkpoint_ids(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([ + {"checkpoint_id": "cp-1"}, + {"checkpoint_id": "cp-2"}, + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + ids = await storage.list_checkpoint_ids(workflow_name="test-workflow") + + assert ids == ["cp-1", "cp-2"] + + async def test_empty_returns_empty(self, mock_container: MagicMock) -> None: + mock_container.query_items.return_value = _to_async_iter([]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + ids = await storage.list_checkpoint_ids(workflow_name="test-workflow") + + assert ids == [] + + +class TestCosmosCheckpointStorageClose: + async def test_close_closes_owned_client( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", 
mock_factory) + + storage = CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="checkpoints", + ) + + await storage.close() + + mock_cosmos_client.close.assert_awaited_once() + + async def test_close_does_not_close_external_client(self, mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="checkpoints", + ) + + await storage.close() + + mock_cosmos_client.close.assert_not_awaited() + + async def test_async_context_manager_closes_owned_client( + self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock + ) -> None: + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) + + async with CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential="key-123", + database_name="db1", + container_name="checkpoints", + ) as storage: + assert storage is not None + + mock_cosmos_client.close.assert_awaited_once() + + async def test_async_context_manager_preserves_original_exception( + self, mock_container: MagicMock + ) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + + with ( + patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))), + pytest.raises(ValueError, match="inner error"), + ): + async with storage: + raise ValueError("inner error") + + +class TestCosmosCheckpointStorageSaveLoadRoundTrip: + async def test_round_trip_preserves_data(self, mock_container: MagicMock) -> None: + """Test that saving and loading a checkpoint preserves all data.""" + checkpoint = _make_checkpoint( + checkpoint_id="cp-roundtrip", + previous_checkpoint_id="cp-parent", + ) + checkpoint.state = {"key": "value", "nested": {"a": 1}} + checkpoint.metadata = {"superstep": 3} + checkpoint.iteration_count = 5 + + # Capture the 
document that was saved + saved_doc: dict[str, Any] = {} + + async def capture_upsert(body: dict[str, Any]) -> dict[str, Any]: + saved_doc.update(body) + return body + + mock_container.upsert_item = AsyncMock(side_effect=capture_upsert) + + storage = CosmosCheckpointStorage(container_client=mock_container) + await storage.save(checkpoint) + + # Simulate Cosmos returning the saved document with system properties + returned_doc = { + **saved_doc, + "_rid": "abc", + "_self": "dbs/abc/colls/def/docs/ghi", + "_etag": '"etag"', + "_attachments": "attachments/", + "_ts": 1700000000, + } + mock_container.query_items.return_value = _to_async_iter([returned_doc]) + + loaded = await storage.load("cp-roundtrip") + + assert loaded.checkpoint_id == checkpoint.checkpoint_id + assert loaded.workflow_name == checkpoint.workflow_name + assert loaded.graph_signature_hash == checkpoint.graph_signature_hash + assert loaded.previous_checkpoint_id == "cp-parent" + assert loaded.state == {"key": "value", "nested": {"a": 1}} + assert loaded.metadata == {"superstep": 3} + assert loaded.iteration_count == 5 + assert loaded.version == "1.0" + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_cosmos_integration_tests_disabled +async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None: + endpoint = os.getenv("AZURE_COSMOS_ENDPOINT", "") + key = os.getenv("AZURE_COSMOS_KEY", "") + database_prefix = os.getenv("AZURE_COSMOS_DATABASE_NAME", "") + container_prefix = os.getenv("AZURE_COSMOS_CONTAINER_NAME", "") + unique = uuid.uuid4().hex[:8] + database_name = f"{database_prefix}-cp-{unique}" + container_name = f"{container_prefix}-cp-{unique}" + + async with CosmosClient(url=endpoint, credential=key) as cosmos_client: + await cosmos_client.create_database_if_not_exists(id=database_name) + + storage = CosmosCheckpointStorage( + cosmos_client=cosmos_client, + database_name=database_name, + container_name=container_name, + ) + + try: + # Save two checkpoints for the same 
workflow + cp1 = _make_checkpoint( + checkpoint_id="cp-int-1", + workflow_name="integration-wf", + timestamp="2025-01-01T00:00:00+00:00", + ) + cp2 = _make_checkpoint( + checkpoint_id="cp-int-2", + workflow_name="integration-wf", + previous_checkpoint_id="cp-int-1", + timestamp="2025-01-02T00:00:00+00:00", + ) + cp2.state = {"step": 2} + + await storage.save(cp1) + await storage.save(cp2) + + # Load by ID + loaded = await storage.load("cp-int-1") + assert loaded.checkpoint_id == "cp-int-1" + assert loaded.workflow_name == "integration-wf" + + # List all checkpoints for workflow + all_cps = await storage.list_checkpoints(workflow_name="integration-wf") + assert len(all_cps) == 2 + + # List checkpoint IDs + ids = await storage.list_checkpoint_ids(workflow_name="integration-wf") + assert "cp-int-1" in ids + assert "cp-int-2" in ids + + # Get latest + latest = await storage.get_latest(workflow_name="integration-wf") + assert latest is not None + assert latest.checkpoint_id == "cp-int-2" + assert latest.state == {"step": 2} + + # Delete + assert await storage.delete("cp-int-1") is True + assert await storage.delete("cp-int-1") is False + + remaining = await storage.list_checkpoint_ids(workflow_name="integration-wf") + assert remaining == ["cp-int-2"] + + # Cross-workflow isolation + other_cp = _make_checkpoint( + checkpoint_id="cp-other", + workflow_name="other-wf", + ) + await storage.save(other_cp) + wf_cps = await storage.list_checkpoints(workflow_name="integration-wf") + assert len(wf_cps) == 1 + assert wf_cps[0].checkpoint_id == "cp-int-2" + + finally: + with suppress(Exception): + await cosmos_client.delete_database(database_name) diff --git a/python/samples/02-agents/conversations/workflow_checkpointing.py b/python/samples/02-agents/conversations/workflow_checkpointing.py new file mode 100644 index 0000000000..7285610a5b --- /dev/null +++ b/python/samples/02-agents/conversations/workflow_checkpointing.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft. 
All rights reserved. +# ruff: noqa: T201 + +""" +Sample: Workflow Checkpointing with Cosmos DB NoSQL + +Purpose: +This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint +storage backend for workflows, enabling durable pause-and-resume across +process restarts. + +What you learn: +- How to configure CosmosCheckpointStorage for workflow checkpointing +- How to run a workflow that automatically persists checkpoints to Cosmos DB +- How to resume a workflow from a Cosmos DB checkpoint +- How to list and inspect available checkpoints + +Prerequisites: +- An Azure Cosmos DB account (or local emulator) +- Environment variables set (see below) + +Environment variables: + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name + AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +import sys +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowCheckpoint, + WorkflowContext, + handler, +) + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +from agent_framework_azure_cosmos import CosmosCheckpointStorage + + +@dataclass +class ComputeTask: + """Task containing the list of numbers remaining to be processed.""" + + remaining_numbers: list[int] + + +class StartExecutor(Executor): + """Initiates the workflow by providing the upper limit.""" + + @handler + async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None: + print(f"StartExecutor: Starting computation up to {upper_limit}") + await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1)))) + + +class WorkerExecutor(Executor): + """Processes numbers and manages executor state for 
checkpointing.""" + + def __init__(self, id: str) -> None: + super().__init__(id=id) + self._results: dict[int, list[tuple[int, int]]] = {} + + @handler + async def compute( + self, + task: ComputeTask, + ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]], + ) -> None: + next_number = task.remaining_numbers.pop(0) + print(f"WorkerExecutor: Processing {next_number}") + + pairs: list[tuple[int, int]] = [] + for i in range(1, next_number): + if next_number % i == 0: + pairs.append((i, next_number // i)) + self._results[next_number] = pairs + + if not task.remaining_numbers: + await ctx.yield_output(self._results) + else: + await ctx.send_message(task) + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + return {"results": self._results} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + self._results = state.get("results", {}) + + +async def main() -> None: + """Run the workflow checkpointing sample with Cosmos DB.""" + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: + print( + "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " + "and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # Authentication: supports both managed identity/RBAC and key-based auth. + # + # Option 1 — Managed identity / RBAC (recommended for production): + # from azure.identity.aio import DefaultAzureCredential + # credential = DefaultAzureCredential() + # + # Option 2 — Account key: + # credential = cosmos_key (or set AZURE_COSMOS_KEY env var) + # + # This sample uses key-based auth when AZURE_COSMOS_KEY is set, + # otherwise falls back to DefaultAzureCredential. 
+ credential: Any + if cosmos_key: + credential = cosmos_key + else: + from azure.identity.aio import DefaultAzureCredential + + credential = DefaultAzureCredential() + + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + # Build workflow with Cosmos DB checkpointing + start = StartExecutor(id="start") + worker = WorkerExecutor(id="worker") + workflow_builder = ( + WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage) + .add_edge(start, worker) + .add_edge(worker, worker) + ) + + # --- First run: execute and stop after 3 iterations --- + print("\n=== First Run ===\n") + workflow = workflow_builder.build() + event_stream = workflow.run(message=8, stream=True) + + async for event in event_stream: + if event.type == "output": + print(f"\nWorkflow completed: {event.data}") + break + if event.type == "superstep_completed": + print(f" [superstep completed, iteration {event.data}]") + if event.data >= 3: + print("\n** Stopping after 3 iterations **") + break + + # List checkpoints saved in Cosmos DB + checkpoint_ids = await checkpoint_storage.list_checkpoint_ids( + workflow_name=workflow.name, + ) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}") + for cid in checkpoint_ids: + print(f" - {cid}") + + # Get the latest checkpoint + latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest( + workflow_name=workflow.name, + ) + + if latest is None: + print("No checkpoint found to resume from.") + return + + print(f"\nLatest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") + + # --- Second run: resume from the latest checkpoint --- + print("\n=== Resuming from Checkpoint ===\n") + workflow2 = workflow_builder.build() + event_stream2 = workflow2.run( + checkpoint_id=latest.checkpoint_id, + stream=True, + 
) + + async for event in event_stream2: + if event.type == "output": + print(f"\nWorkflow completed: {event.data}") + break + if event.type == "superstep_completed": + print(f" [superstep completed, iteration {event.data}]") + + +if __name__ == "__main__": + asyncio.run(main()) From b36459b09a8e789ab3d3d7f381583229430c10e4 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Wed, 25 Mar 2026 15:08:52 -0700 Subject: [PATCH 2/8] Adding cosmos checkpointer --- python/packages/azure-cosmos/README.md | 13 +- .../_checkpoint_storage.py | 21 +- .../packages/azure-cosmos/samples/README.md | 24 + .../tests/test_cosmos_checkpoint_storage.py | 599 +++++++++--------- ...ng.py => cosmos_workflow_checkpointing.py} | 33 +- .../cosmos_workflow_checkpointing_foundry.py | 145 +++++ 6 files changed, 514 insertions(+), 321 deletions(-) create mode 100644 python/packages/azure-cosmos/samples/README.md rename python/samples/02-agents/conversations/{workflow_checkpointing.py => cosmos_workflow_checkpointing.py} (87%) create mode 100644 python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md index 8a46b0ca7a..3351f3c7c1 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -105,11 +105,12 @@ if latest: Azure credentials) - **Pre-created client**: Pass an existing `CosmosClient` or `ContainerProxy` -### Container Setup +### Database and Container Setup -The container is created automatically on first use with `/workflow_name` as the -partition key. You can also pre-create the container in the Azure portal with this -partition key configuration. +The database and container are created automatically on first use (via +`create_database_if_not_exists` and `create_container_if_not_exists`). The container +uses `/workflow_name` as the partition key. You can also pre-create them in the Azure +portal with this partition key configuration. 
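If you pre-create the container with a script or infrastructure-as-code rather than the portal, the only hard requirement is that the partition key path match `/workflow_name`. The sketch below is illustrative only; the helper name is hypothetical and not part of this package, and throughput/indexing settings are left to you:

```python
# Illustrative only: a container definition whose partition key matches what
# CosmosCheckpointStorage creates via create_container_if_not_exists.
# The path must be exactly "/workflow_name" for per-workflow queries to stay
# partition-scoped.

def checkpoint_container_settings(container_name: str) -> dict:
    """Build a Cosmos container definition for pre-creation (hypothetical helper)."""
    return {
        "id": container_name,
        "partitionKey": {"paths": ["/workflow_name"], "kind": "Hash"},
    }


settings = checkpoint_container_settings("checkpoints")
print(settings["partitionKey"]["paths"][0])  # /workflow_name
```

The same shape can be expressed in an ARM/Bicep template or the Azure CLI; what matters is that a pre-created container and the auto-created one agree on the partition key path.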
### Environment Variables @@ -120,4 +121,6 @@ partition key configuration. | `AZURE_COSMOS_CONTAINER_NAME` | Container name | | `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) | -See `samples/workflow_checkpointing.py` for a complete runnable example. +See `samples/cosmos_workflow_checkpointing.py` for a standalone example, or +`samples/cosmos_workflow_checkpointing_foundry.py` for an end-to-end example +with Azure AI Foundry agents. diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 5c567a1645..ce608c8a52 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -14,7 +14,7 @@ from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import WorkflowCheckpointException from azure.cosmos import PartitionKey -from azure.cosmos.aio import ContainerProxy, CosmosClient, DatabaseProxy +from azure.cosmos.aio import ContainerProxy, CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError logger = logging.getLogger(__name__) @@ -44,8 +44,9 @@ class CosmosCheckpointStorage: checkpoints from trusted sources. Loading a malicious checkpoint can execute arbitrary code. - The container is created automatically on first use with partition key - ``/workflow_name`` if it does not already exist. + The database and container are created automatically on first use + if they do not already exist. The container uses partition key + ``/workflow_name``. 
Example using managed identity / RBAC:: @@ -123,7 +124,6 @@ def __init__( self._cosmos_client: CosmosClient | None = cosmos_client self._container_proxy: ContainerProxy | None = container_client self._owns_client = False - self._database_client: DatabaseProxy | None = None if self._container_proxy is not None: self.database_name: str = database_name or "" @@ -158,8 +158,6 @@ def __init__( ) self._owns_client = True - self._database_client = self._cosmos_client.get_database_client(self.database_name) - async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID: """Save a checkpoint to Cosmos DB and return its ID. @@ -210,7 +208,6 @@ async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint: items = self._container_proxy.query_items( # type: ignore[union-attr] query=query, parameters=parameters, - enable_cross_partition_query=True, ) async for item in items: @@ -268,7 +265,6 @@ async def delete(self, checkpoint_id: CheckpointID) -> bool: items = self._container_proxy.query_items( # type: ignore[union-attr] query=query, parameters=parameters, - enable_cross_partition_query=True, ) async for item in items: @@ -375,13 +371,14 @@ async def __aexit__( raise async def _ensure_container_proxy(self) -> None: - """Get or create the Cosmos DB container for storing checkpoints.""" + """Get or create the Cosmos DB database and container for storing checkpoints.""" if self._container_proxy is not None: return - if self._database_client is None: - raise RuntimeError("Cosmos database client is not initialized.") + if self._cosmos_client is None: + raise RuntimeError("Cosmos client is not initialized.") - self._container_proxy = await self._database_client.create_container_if_not_exists( + database = await self._cosmos_client.create_database_if_not_exists(id=self.database_name) + self._container_proxy = await database.create_container_if_not_exists( id=self.container_name, partition_key=PartitionKey(path="/workflow_name"), ) diff --git 
a/python/packages/azure-cosmos/samples/README.md b/python/packages/azure-cosmos/samples/README.md new file mode 100644 index 0000000000..9c767a56a1 --- /dev/null +++ b/python/packages/azure-cosmos/samples/README.md @@ -0,0 +1,24 @@ +# Azure Cosmos DB Package Samples + +This folder contains samples for `agent-framework-azure-cosmos`. + +| File | Description | +| --- | --- | +| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. | +| [`cosmos_workflow_checkpointing.py`](cosmos_workflow_checkpointing.py) | Workflow checkpoint storage with Cosmos DB — pause and resume workflows across restarts using `CosmosCheckpointStorage`, with support for key-based and managed identity auth. | +| [`cosmos_workflow_checkpointing_foundry.py`](cosmos_workflow_checkpointing_foundry.py) | End-to-end Azure AI Foundry + Cosmos DB checkpointing — multi-agent workflow using `AzureOpenAIResponsesClient` with `CosmosCheckpointStorage` for durable pause/resume. 
| + +## Prerequisites + +- `AZURE_COSMOS_ENDPOINT` +- `AZURE_COSMOS_DATABASE_NAME` +- `AZURE_COSMOS_CONTAINER_NAME` +- `AZURE_COSMOS_KEY` (or equivalent credential flow) + +## Run + +```bash +uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py +uv run --directory packages/azure-cosmos python samples/cosmos_workflow_checkpointing.py +uv run --directory packages/azure-cosmos python samples/cosmos_workflow_checkpointing_foundry.py +``` diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py index 34bbf46b0a..7734c3060a 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -94,363 +94,392 @@ def mock_cosmos_client(mock_container: MagicMock) -> MagicMock: database_client.create_container_if_not_exists = AsyncMock(return_value=mock_container) client = MagicMock() - client.get_database_client.return_value = database_client + client.create_database_if_not_exists = AsyncMock(return_value=database_client) client.close = AsyncMock() return client -class TestCosmosCheckpointStorageInit: - def test_uses_provided_container_client(self, mock_container: MagicMock) -> None: - storage = CosmosCheckpointStorage(container_client=mock_container) - assert storage.database_name == "" - assert storage.container_name == "" +# --- Tests for initialization --- - def test_uses_provided_cosmos_client(self, mock_cosmos_client: MagicMock) -> None: - storage = CosmosCheckpointStorage( - cosmos_client=mock_cosmos_client, - database_name="db1", - container_name="checkpoints", - ) - mock_cosmos_client.get_database_client.assert_called_once_with("db1") - assert storage.database_name == "db1" - assert storage.container_name == "checkpoints" - - def test_missing_required_settings_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: - 
monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False) - monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False) - monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False) - monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) - - with pytest.raises(SettingNotFoundError, match="database_name"): - CosmosCheckpointStorage() - - def test_constructs_client_with_credential( - self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock - ) -> None: - """Uses key-based auth when a key string is provided, otherwise falls back to Azure credential (RBAC).""" - mock_factory = MagicMock(return_value=mock_cosmos_client) - monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) - monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) - - # Simulate real-world pattern: use key if available, else RBAC credential - cosmos_key = os.getenv("AZURE_COSMOS_KEY") - credential: Any = cosmos_key if cosmos_key else MagicMock() # MagicMock simulates DefaultAzureCredential() - - CosmosCheckpointStorage( - endpoint="https://account.documents.azure.com:443/", - credential=credential, - database_name="db1", - container_name="checkpoints", - ) +async def test_init_uses_provided_container_client(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + assert storage.database_name == "" + assert storage.container_name == "" - mock_factory.assert_called_once() - kwargs = mock_factory.call_args.kwargs - assert kwargs["url"] == "https://account.documents.azure.com:443/" - assert kwargs["credential"] is credential +async def test_init_uses_provided_cosmos_client(mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="checkpoints", + ) + assert storage.database_name == "db1" + assert storage.container_name == "checkpoints" -class TestCosmosCheckpointStorageContainerConfig: - async def 
test_container_name_is_used(self, mock_cosmos_client: MagicMock) -> None: - storage = CosmosCheckpointStorage( - cosmos_client=mock_cosmos_client, - database_name="db1", - container_name="custom-checkpoints", - ) - await storage.list_checkpoint_ids(workflow_name="wf") +async def test_init_missing_required_settings_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("AZURE_COSMOS_ENDPOINT", raising=False) + monkeypatch.delenv("AZURE_COSMOS_DATABASE_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_CONTAINER_NAME", raising=False) + monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) - database_client = mock_cosmos_client.get_database_client.return_value - assert database_client.create_container_if_not_exists.await_count == 1 - kwargs = database_client.create_container_if_not_exists.await_args.kwargs - assert kwargs["id"] == "custom-checkpoints" + with pytest.raises(SettingNotFoundError, match="database_name"): + CosmosCheckpointStorage() -class TestCosmosCheckpointStorageSave: - async def test_save_upserts_document(self, mock_container: MagicMock) -> None: - storage = CosmosCheckpointStorage(container_client=mock_container) - checkpoint = _make_checkpoint() +async def test_init_constructs_client_with_credential( + monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock +) -> None: + """Uses key-based auth when a key string is provided, otherwise falls back to Azure credential (RBAC).""" + mock_factory = MagicMock(return_value=mock_cosmos_client) + monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory) + monkeypatch.delenv("AZURE_COSMOS_KEY", raising=False) - result = await storage.save(checkpoint) + # Simulate real-world pattern: use key if available, else RBAC credential + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + credential: Any = cosmos_key if cosmos_key else MagicMock() # MagicMock simulates DefaultAzureCredential() - assert result == checkpoint.checkpoint_id - 
mock_container.upsert_item.assert_awaited_once() - document = mock_container.upsert_item.await_args.kwargs["body"] - assert document["id"] == checkpoint.checkpoint_id - assert document["workflow_name"] == "test-workflow" - assert document["graph_signature_hash"] == "abc123" - assert document["state"]["counter"] == 42 + CosmosCheckpointStorage( + endpoint="https://account.documents.azure.com:443/", + credential=credential, + database_name="db1", + container_name="checkpoints", + ) - async def test_save_returns_checkpoint_id(self, mock_container: MagicMock) -> None: - storage = CosmosCheckpointStorage(container_client=mock_container) - checkpoint = _make_checkpoint(checkpoint_id="cp-123") + mock_factory.assert_called_once() + kwargs = mock_factory.call_args.kwargs + assert kwargs["url"] == "https://account.documents.azure.com:443/" + assert kwargs["credential"] is credential - result = await storage.save(checkpoint) - assert result == "cp-123" +async def test_init_creates_database_and_container(mock_cosmos_client: MagicMock) -> None: + storage = CosmosCheckpointStorage( + cosmos_client=mock_cosmos_client, + database_name="db1", + container_name="custom-checkpoints", + ) + await storage.list_checkpoint_ids(workflow_name="wf") -class TestCosmosCheckpointStorageLoad: - async def test_load_returns_checkpoint(self, mock_container: MagicMock) -> None: - checkpoint = _make_checkpoint(checkpoint_id="cp-load") - doc = _checkpoint_to_cosmos_document(checkpoint) - mock_container.query_items.return_value = _to_async_iter([doc]) + mock_cosmos_client.create_database_if_not_exists.assert_awaited_once_with(id="db1") + database_client = mock_cosmos_client.create_database_if_not_exists.return_value + assert database_client.create_container_if_not_exists.await_count == 1 + kwargs = database_client.create_container_if_not_exists.await_args.kwargs + assert kwargs["id"] == "custom-checkpoints" - storage = CosmosCheckpointStorage(container_client=mock_container) - loaded = await 
storage.load("cp-load") - assert loaded.checkpoint_id == "cp-load" - assert loaded.workflow_name == "test-workflow" - assert loaded.graph_signature_hash == "abc123" - assert loaded.state["counter"] == 42 +# --- Tests for save --- - async def test_load_nonexistent_raises(self, mock_container: MagicMock) -> None: - mock_container.query_items.return_value = _to_async_iter([]) - storage = CosmosCheckpointStorage(container_client=mock_container) +async def test_save_upserts_document(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint() - with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"): - await storage.load("nonexistent-id") + result = await storage.save(checkpoint) - async def test_load_uses_cross_partition_query(self, mock_container: MagicMock) -> None: - mock_container.query_items.return_value = _to_async_iter([]) + assert result == checkpoint.checkpoint_id + mock_container.upsert_item.assert_awaited_once() + document = mock_container.upsert_item.await_args.kwargs["body"] + assert document["id"] == checkpoint.checkpoint_id + assert document["workflow_name"] == "test-workflow" + assert document["graph_signature_hash"] == "abc123" + assert document["state"]["counter"] == 42 - storage = CosmosCheckpointStorage(container_client=mock_container) - with suppress(WorkflowCheckpointException): - await storage.load("cp-id") - kwargs = mock_container.query_items.call_args.kwargs - assert kwargs["enable_cross_partition_query"] is True +async def test_save_returns_checkpoint_id(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + checkpoint = _make_checkpoint(checkpoint_id="cp-123") + result = await storage.save(checkpoint) -class TestCosmosCheckpointStorageListCheckpoints: - async def test_returns_checkpoints_for_workflow(self, mock_container: MagicMock) -> None: - cp1 = _make_checkpoint(checkpoint_id="cp-1", 
timestamp="2025-01-01T00:00:00+00:00")
-        cp2 = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00")
-        mock_container.query_items.return_value = _to_async_iter([
-            _checkpoint_to_cosmos_document(cp1),
-            _checkpoint_to_cosmos_document(cp2),
-        ])
+    assert result == "cp-123"
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        results = await storage.list_checkpoints(workflow_name="test-workflow")
-        assert len(results) == 2
-        assert results[0].checkpoint_id == "cp-1"
-        assert results[1].checkpoint_id == "cp-2"
+# --- Tests for load ---
-    async def test_uses_partition_key(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        await storage.list_checkpoints(workflow_name="my-workflow")
+async def test_load_returns_checkpoint(mock_container: MagicMock) -> None:
+    checkpoint = _make_checkpoint(checkpoint_id="cp-load")
+    doc = _checkpoint_to_cosmos_document(checkpoint)
+    mock_container.query_items.return_value = _to_async_iter([doc])
-        kwargs = mock_container.query_items.call_args.kwargs
-        assert kwargs["partition_key"] == "my-workflow"
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    loaded = await storage.load("cp-load")
-    async def test_empty_returns_empty(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
+    assert loaded.checkpoint_id == "cp-load"
+    assert loaded.workflow_name == "test-workflow"
+    assert loaded.graph_signature_hash == "abc123"
+    assert loaded.state["counter"] == 42
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        results = await storage.list_checkpoints(workflow_name="test-workflow")
-        assert results == []
+async def test_load_nonexistent_raises(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
-class TestCosmosCheckpointStorageDelete:
-    async def test_delete_existing_returns_true(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([
-            {"id": "cp-del", "workflow_name": "test-workflow"},
-        ])
+    with pytest.raises(WorkflowCheckpointException, match="No checkpoint found"):
+        await storage.load("nonexistent-id")
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        result = await storage.delete("cp-del")
-        assert result is True
-        mock_container.delete_item.assert_awaited_once_with(
-            item="cp-del",
-            partition_key="test-workflow",
-        )
+async def test_load_queries_without_partition_key(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    with suppress(WorkflowCheckpointException):
+        await storage.load("cp-id")
-    async def test_delete_nonexistent_returns_false(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
+    kwargs = mock_container.query_items.call_args.kwargs
+    assert "partition_key" not in kwargs
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        result = await storage.delete("nonexistent")
-        assert result is False
-        mock_container.delete_item.assert_not_awaited()
+# --- Tests for list_checkpoints ---
-    async def test_delete_cosmos_not_found_returns_false(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([
-            {"id": "cp-del", "workflow_name": "test-workflow"},
-        ])
-        mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError)
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        result = await storage.delete("cp-del")
+async def test_list_checkpoints_returns_checkpoints_for_workflow(mock_container: MagicMock) -> None:
+    cp1 = _make_checkpoint(checkpoint_id="cp-1", timestamp="2025-01-01T00:00:00+00:00")
+    cp2 = _make_checkpoint(checkpoint_id="cp-2", timestamp="2025-01-02T00:00:00+00:00")
+    mock_container.query_items.return_value = _to_async_iter([
+        _checkpoint_to_cosmos_document(cp1),
+        _checkpoint_to_cosmos_document(cp2),
+    ])
-        assert result is False
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    results = await storage.list_checkpoints(workflow_name="test-workflow")
+    assert len(results) == 2
+    assert results[0].checkpoint_id == "cp-1"
+    assert results[1].checkpoint_id == "cp-2"
-class TestCosmosCheckpointStorageGetLatest:
-    async def test_returns_latest_checkpoint(self, mock_container: MagicMock) -> None:
-        cp = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00")
-        mock_container.query_items.return_value = _to_async_iter([
-            _checkpoint_to_cosmos_document(cp),
-        ])
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        result = await storage.get_latest(workflow_name="test-workflow")
+async def test_list_checkpoints_uses_partition_key(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
-        assert result is not None
-        assert result.checkpoint_id == "cp-latest"
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    await storage.list_checkpoints(workflow_name="my-workflow")
-    async def test_returns_none_when_empty(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
+    kwargs = mock_container.query_items.call_args.kwargs
+    assert kwargs["partition_key"] == "my-workflow"
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        result = await storage.get_latest(workflow_name="test-workflow")
-        assert result is None
+async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
-    async def test_uses_order_by_desc_with_limit(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    results = await storage.list_checkpoints(workflow_name="test-workflow")
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        await storage.get_latest(workflow_name="test-workflow")
+    assert results == []
-        kwargs = mock_container.query_items.call_args.kwargs
-        assert "ORDER BY c.timestamp DESC" in kwargs["query"]
-        assert "OFFSET 0 LIMIT 1" in kwargs["query"]
+# --- Tests for delete ---
-class TestCosmosCheckpointStorageListIds:
-    async def test_returns_checkpoint_ids(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([
-            {"checkpoint_id": "cp-1"},
-            {"checkpoint_id": "cp-2"},
-        ])
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")
+async def test_delete_existing_returns_true(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([
+        {"id": "cp-del", "workflow_name": "test-workflow"},
+    ])
-        assert ids == ["cp-1", "cp-2"]
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    result = await storage.delete("cp-del")
-    async def test_empty_returns_empty(self, mock_container: MagicMock) -> None:
-        mock_container.query_items.return_value = _to_async_iter([])
+    assert result is True
+    mock_container.delete_item.assert_awaited_once_with(
+        item="cp-del",
+        partition_key="test-workflow",
+    )
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")
-        assert ids == []
+async def test_delete_nonexistent_returns_false(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    result = await storage.delete("nonexistent")
-class TestCosmosCheckpointStorageClose:
-    async def test_close_closes_owned_client(
-        self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
-    ) -> None:
-        mock_factory = MagicMock(return_value=mock_cosmos_client)
-        monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
+    assert result is False
+    mock_container.delete_item.assert_not_awaited()
-        storage = CosmosCheckpointStorage(
-            endpoint="https://account.documents.azure.com:443/",
-            credential="key-123",
-            database_name="db1",
-            container_name="checkpoints",
-        )
-        await storage.close()
+async def test_delete_cosmos_not_found_returns_false(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([
+        {"id": "cp-del", "workflow_name": "test-workflow"},
+    ])
+    mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError)
-        mock_cosmos_client.close.assert_awaited_once()
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    result = await storage.delete("cp-del")
+
+    assert result is False
+
+
+# --- Tests for get_latest ---
+
+
+async def test_get_latest_returns_latest_checkpoint(mock_container: MagicMock) -> None:
+    cp = _make_checkpoint(checkpoint_id="cp-latest", timestamp="2025-06-01T00:00:00+00:00")
+    mock_container.query_items.return_value = _to_async_iter([
+        _checkpoint_to_cosmos_document(cp),
+    ])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    result = await storage.get_latest(workflow_name="test-workflow")
+
+    assert result is not None
+    assert result.checkpoint_id == "cp-latest"
+
+
+async def test_get_latest_returns_none_when_empty(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    result = await storage.get_latest(workflow_name="test-workflow")
+
+    assert result is None
+
+
+async def test_get_latest_uses_order_by_desc_with_limit(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    await storage.get_latest(workflow_name="test-workflow")
+
+    kwargs = mock_container.query_items.call_args.kwargs
+    assert "ORDER BY c.timestamp DESC" in kwargs["query"]
+    assert "OFFSET 0 LIMIT 1" in kwargs["query"]
+
+
+# --- Tests for list_checkpoint_ids ---
+
+
+async def test_list_checkpoint_ids_returns_ids(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([
+        {"checkpoint_id": "cp-1"},
+        {"checkpoint_id": "cp-2"},
+    ])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")
+
+    assert ids == ["cp-1", "cp-2"]
+
+
+async def test_list_checkpoint_ids_empty_returns_empty(mock_container: MagicMock) -> None:
+    mock_container.query_items.return_value = _to_async_iter([])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    ids = await storage.list_checkpoint_ids(workflow_name="test-workflow")
+
+    assert ids == []
+
+
+# --- Tests for close and context manager ---
+
+
+async def test_close_closes_owned_client(
+    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
+) -> None:
+    mock_factory = MagicMock(return_value=mock_cosmos_client)
+    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
+
+    storage = CosmosCheckpointStorage(
+        endpoint="https://account.documents.azure.com:443/",
+        credential="key-123",
+        database_name="db1",
+        container_name="checkpoints",
+    )
+
+    await storage.close()
+
+    mock_cosmos_client.close.assert_awaited_once()
+
+
+async def test_close_does_not_close_external_client(mock_cosmos_client: MagicMock) -> None:
+    storage = CosmosCheckpointStorage(
+        cosmos_client=mock_cosmos_client,
+        database_name="db1",
+        container_name="checkpoints",
+    )
+
+    await storage.close()
+
+    mock_cosmos_client.close.assert_not_awaited()
+
+
+async def test_context_manager_closes_owned_client(
+    monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
+) -> None:
+    mock_factory = MagicMock(return_value=mock_cosmos_client)
+    monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
+
+    async with CosmosCheckpointStorage(
+        endpoint="https://account.documents.azure.com:443/",
+        credential="key-123",
+        database_name="db1",
+        container_name="checkpoints",
+    ) as storage:
+        assert storage is not None
+
+    mock_cosmos_client.close.assert_awaited_once()
+
+
+async def test_context_manager_preserves_original_exception(mock_container: MagicMock) -> None:
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+
+    with (
+        patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))),
+        pytest.raises(ValueError, match="inner error"),
+    ):
+        async with storage:
+            raise ValueError("inner error")
+
+
+# --- Tests for save/load round-trip ---
+
+
+async def test_round_trip_preserves_data(mock_container: MagicMock) -> None:
+    checkpoint = _make_checkpoint(
+        checkpoint_id="cp-roundtrip",
+        previous_checkpoint_id="cp-parent",
+    )
+    checkpoint.state = {"key": "value", "nested": {"a": 1}}
+    checkpoint.metadata = {"superstep": 3}
+    checkpoint.iteration_count = 5
+
+    saved_doc: dict[str, Any] = {}
+
+    async def capture_upsert(body: dict[str, Any]) -> dict[str, Any]:
+        saved_doc.update(body)
+        return body
+
+    mock_container.upsert_item = AsyncMock(side_effect=capture_upsert)
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+    await storage.save(checkpoint)
+
+    returned_doc = {
+        **saved_doc,
+        "_rid": "abc",
+        "_self": "dbs/abc/colls/def/docs/ghi",
+        "_etag": '"etag"',
+        "_attachments": "attachments/",
+        "_ts": 1700000000,
+    }
+    mock_container.query_items.return_value = _to_async_iter([returned_doc])
+
+    loaded = await storage.load("cp-roundtrip")
+
+    assert loaded.checkpoint_id == checkpoint.checkpoint_id
+    assert loaded.workflow_name == checkpoint.workflow_name
+    assert loaded.graph_signature_hash == checkpoint.graph_signature_hash
+    assert loaded.previous_checkpoint_id == "cp-parent"
+    assert loaded.state == {"key": "value", "nested": {"a": 1}}
+    assert loaded.metadata == {"superstep": 3}
+    assert loaded.iteration_count == 5
+    assert loaded.version == "1.0"
+
+
+# --- Integration test ---
-    async def test_close_does_not_close_external_client(self, mock_cosmos_client: MagicMock) -> None:
-        storage = CosmosCheckpointStorage(
-            cosmos_client=mock_cosmos_client,
-            database_name="db1",
-            container_name="checkpoints",
-        )
-        await storage.close()
-
-        mock_cosmos_client.close.assert_not_awaited()
-
-    async def test_async_context_manager_closes_owned_client(
-        self, monkeypatch: pytest.MonkeyPatch, mock_cosmos_client: MagicMock
-    ) -> None:
-        mock_factory = MagicMock(return_value=mock_cosmos_client)
-        monkeypatch.setattr(checkpoint_storage_module, "CosmosClient", mock_factory)
-
-        async with CosmosCheckpointStorage(
-            endpoint="https://account.documents.azure.com:443/",
-            credential="key-123",
-            database_name="db1",
-            container_name="checkpoints",
-        ) as storage:
-            assert storage is not None
-
-        mock_cosmos_client.close.assert_awaited_once()
-
-    async def test_async_context_manager_preserves_original_exception(
-        self, mock_container: MagicMock
-    ) -> None:
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-
-        with (
-            patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))),
-            pytest.raises(ValueError, match="inner error"),
-        ):
-            async with storage:
-                raise ValueError("inner error")
-
-
-class TestCosmosCheckpointStorageSaveLoadRoundTrip:
-    async def test_round_trip_preserves_data(self, mock_container: MagicMock) -> None:
-        """Test that saving and loading a checkpoint preserves all data."""
-        checkpoint = _make_checkpoint(
-            checkpoint_id="cp-roundtrip",
-            previous_checkpoint_id="cp-parent",
-        )
-        checkpoint.state = {"key": "value", "nested": {"a": 1}}
-        checkpoint.metadata = {"superstep": 3}
-        checkpoint.iteration_count = 5
-
-        # Capture the document that was saved
-        saved_doc: dict[str, Any] = {}
-
-        async def capture_upsert(body: dict[str, Any]) -> dict[str, Any]:
-            saved_doc.update(body)
-            return body
-
-        mock_container.upsert_item = AsyncMock(side_effect=capture_upsert)
-
-        storage = CosmosCheckpointStorage(container_client=mock_container)
-        await storage.save(checkpoint)
-
-        # Simulate Cosmos returning the saved document with system properties
-        returned_doc = {
-            **saved_doc,
-            "_rid": "abc",
-            "_self": "dbs/abc/colls/def/docs/ghi",
-            "_etag": '"etag"',
-            "_attachments": "attachments/",
-            "_ts": 1700000000,
-        }
-        mock_container.query_items.return_value = _to_async_iter([returned_doc])
-
-        loaded = await storage.load("cp-roundtrip")
-
-        assert loaded.checkpoint_id == checkpoint.checkpoint_id
-        assert loaded.workflow_name == checkpoint.workflow_name
-        assert loaded.graph_signature_hash == checkpoint.graph_signature_hash
-        assert loaded.previous_checkpoint_id == "cp-parent"
-        assert loaded.state == {"key": "value", "nested": {"a": 1}}
-        assert loaded.metadata == {"superstep": 3}
-        assert loaded.iteration_count == 5
-        assert loaded.version == "1.0"
-
-
-@pytest.mark.flaky
 @pytest.mark.integration
 @skip_if_cosmos_integration_tests_disabled
 async def test_cosmos_checkpoint_storage_roundtrip_with_emulator() -> None:
diff --git a/python/samples/02-agents/conversations/workflow_checkpointing.py b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
similarity index 87%
rename from python/samples/02-agents/conversations/workflow_checkpointing.py
rename to python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
index 7285610a5b..cc8872f110 100644
--- a/python/samples/02-agents/conversations/workflow_checkpointing.py
+++ b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
@@ -149,20 +149,16 @@ async def main() -> None:
             .add_edge(worker, worker)
         )
 
-        # --- First run: execute and stop after 3 iterations ---
+        # --- First run: execute the workflow ---
         print("\n=== First Run ===\n")
         workflow = workflow_builder.build()
 
-        event_stream = workflow.run(message=8, stream=True)
-        async for event in event_stream:
+        output = None
+        async for event in workflow.run(message=8, stream=True):
             if event.type == "output":
-                print(f"\nWorkflow completed: {event.data}")
-                break
-            if event.type == "superstep_completed":
-                print(f"  [superstep completed, iteration {event.data}]")
-                if event.data >= 3:
-                    print("\n** Stopping after 3 iterations **")
-                    break
+                output = event.data
+
+        print(f"Factor pairs computed: {output}")
 
         # List checkpoints saved in Cosmos DB
         checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
@@ -188,17 +184,16 @@ async def main() -> None:
         # --- Second run: resume from the latest checkpoint ---
         print("\n=== Resuming from Checkpoint ===\n")
         workflow2 = workflow_builder.build()
 
-        event_stream2 = workflow2.run(
-            checkpoint_id=latest.checkpoint_id,
-            stream=True,
-        )
-        async for event in event_stream2:
+        output2 = None
+        async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True):
             if event.type == "output":
-                print(f"\nWorkflow completed: {event.data}")
-                break
-            if event.type == "superstep_completed":
-                print(f"  [superstep completed, iteration {event.data}]")
+                output2 = event.data
+
+        if output2:
+            print(f"Resumed workflow produced: {output2}")
+        else:
+            print("Resumed workflow completed (no remaining work — already finished).")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py
new file mode 100644
index 0000000000..6d20a88ac5
--- /dev/null
+++ b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py
@@ -0,0 +1,145 @@
+# Copyright (c) Microsoft. All rights reserved.
+# ruff: noqa: T201
+
+"""
+Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry
+
+Purpose:
+This sample demonstrates how to use CosmosCheckpointStorage with agents built
+on Azure AI Foundry (via AzureOpenAIResponsesClient). It shows a multi-agent
+workflow where checkpoint state is persisted to Cosmos DB, enabling durable
+pause-and-resume across process restarts.
+
+What you learn:
+- How to wire CosmosCheckpointStorage with AzureOpenAIResponsesClient agents
+- How to combine session history with workflow checkpointing
+- How to resume a workflow-as-agent from a Cosmos DB checkpoint
+
+Key concepts:
+- AgentSession: Maintains conversation history across agent invocations
+- CosmosCheckpointStorage: Persists workflow execution state in Cosmos DB
+- These are complementary: sessions track conversation, checkpoints track workflow state
+
+Environment variables:
+    AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint
+    AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name
+    AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
+    AZURE_COSMOS_DATABASE_NAME - Database name
+    AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
+Optional:
+    AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
+"""
+
+import asyncio
+import os
+from typing import Any
+
+from agent_framework.azure import AzureOpenAIResponsesClient
+from agent_framework.orchestrations import SequentialBuilder
+from azure.identity.aio import AzureCliCredential
+from dotenv import load_dotenv
+
+from agent_framework_azure_cosmos import CosmosCheckpointStorage
+
+load_dotenv()
+
+
+async def main() -> None:
+    """Run the Azure AI Foundry + Cosmos DB checkpointing sample."""
+    project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
+    deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")
+    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
+    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
+    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
+    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
+
+    if not project_endpoint or not deployment_name:
+        print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.")
+        return
+
+    if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name:
+        print(
+            "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, "
+            "and AZURE_COSMOS_CONTAINER_NAME."
+        )
+        return
+
+    # Authentication: use key if available, otherwise fall back to Azure credential (RBAC)
+    credential: Any
+    if cosmos_key:
+        credential = cosmos_key
+    else:
+        credential = AzureCliCredential()
+
+    async with CosmosCheckpointStorage(
+        endpoint=cosmos_endpoint,
+        credential=credential,
+        database_name=cosmos_database_name,
+        container_name=cosmos_container_name,
+    ) as checkpoint_storage:
+        # Create Azure AI Foundry agents
+        client = AzureOpenAIResponsesClient(
+            project_endpoint=project_endpoint,
+            deployment_name=deployment_name,
+            credential=AzureCliCredential(),
+        )
+
+        assistant = client.as_agent(
+            name="assistant",
+            instructions="You are a helpful assistant. Keep responses brief.",
+        )
+
+        reviewer = client.as_agent(
+            name="reviewer",
+            instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.",
+        )
+
+        # Build a sequential workflow and wrap it as an agent
+        workflow = SequentialBuilder(participants=[assistant, reviewer]).build()
+        agent = workflow.as_agent(name="FoundryCheckpointedAgent")
+
+        # --- First run: execute with Cosmos DB checkpointing ---
+        print("=== First Run ===\n")
+
+        session = agent.create_session()
+        query = "What are the benefits of renewable energy?"
+        print(f"User: {query}")
+
+        response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage)
+
+        for msg in response.messages:
+            speaker = msg.author_name or msg.role
+            print(f"[{speaker}]: {msg.text}")
+
+        # Show checkpoints persisted in Cosmos DB
+        checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
+        print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}")
+        for i, cp in enumerate(checkpoints[:5], 1):
+            print(f"  {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})")
+
+        # --- Second run: continue conversation with checkpoint history ---
+        print("\n=== Second Run (continuing conversation) ===\n")
+
+        query2 = "Can you elaborate on the economic benefits?"
+        print(f"User: {query2}")
+
+        response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage)
+
+        for msg in response2.messages:
+            speaker = msg.author_name or msg.role
+            print(f"[{speaker}]: {msg.text}")
+
+        # Show total checkpoints
+        all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
+        print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}")
+
+        # Get latest checkpoint
+        latest = await checkpoint_storage.get_latest(workflow_name=workflow.name)
+        if latest:
+            print(f"Latest checkpoint: {latest.checkpoint_id}")
+            print(f"  iteration_count: {latest.iteration_count}")
+            print(f"  timestamp: {latest.timestamp}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

From bc32a10d28059dcaa8dfb1b57f8e718ceae1787c Mon Sep 17 00:00:00 2001
From: Aayush Kataria
Date: Thu, 26 Mar 2026 15:15:21 -0700
Subject: [PATCH 3/8] Resolving comments

---
 .../_checkpoint_storage.py                    |  42 +++--
 .../tests/test_cosmos_checkpoint_storage.py   |  22 ++-
 .../cosmos_workflow_checkpointing.py          | 145 +++++++++---------
 .../cosmos_workflow_checkpointing_foundry.py  |  59 ++++---
 4 files changed, 152 insertions(+), 116 deletions(-)

diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py
index ce608c8a52..3bee671e33 100644
--- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py
+++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py
@@ -165,6 +165,9 @@ async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID:
         non-JSON-native values) and stored as a Cosmos DB document with the
         ``workflow_name`` as the partition key.
 
+        The document ``id`` is a composite of ``workflow_name`` and
+        ``checkpoint_id`` to ensure global uniqueness across partitions.
+
         Args:
             checkpoint: The WorkflowCheckpoint object to save.
 
@@ -177,7 +180,7 @@ async def save(self, checkpoint: WorkflowCheckpoint) -> CheckpointID:
         encoded = encode_checkpoint_value(checkpoint_dict)
 
         document: dict[str, Any] = {
-            "id": checkpoint.checkpoint_id,
+            "id": self._make_document_id(checkpoint.workflow_name, checkpoint.checkpoint_id),
             "workflow_name": checkpoint.workflow_name,
             **encoded,
         }
@@ -196,11 +199,12 @@ async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint:
             The WorkflowCheckpoint object corresponding to the given ID.
 
         Raises:
-            WorkflowCheckpointException: If no checkpoint with the given ID exists.
+            WorkflowCheckpointException: If no checkpoint with the given ID exists,
+                or if multiple checkpoints share the same ID across workflows.
         """
         await self._ensure_container_proxy()
 
-        query = "SELECT * FROM c WHERE c.id = @checkpoint_id"
+        query = "SELECT * FROM c WHERE c.checkpoint_id = @checkpoint_id"
         parameters: list[dict[str, object]] = [
             {"name": "@checkpoint_id", "value": checkpoint_id},
         ]
@@ -210,10 +214,22 @@ async def load(self, checkpoint_id: CheckpointID) -> WorkflowCheckpoint:
             parameters=parameters,
         )
 
+        results: list[dict[str, Any]] = []
         async for item in items:
-            return self._document_to_checkpoint(item)
+            results.append(item)
+
+        if not results:
+            raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")
+
+        if len(results) > 1:
+            workflow_names = [r.get("workflow_name", "unknown") for r in results]
+            raise WorkflowCheckpointException(
+                f"Multiple checkpoints found with ID {checkpoint_id} across workflows: "
+                f"{workflow_names}. Use list_checkpoints(workflow_name=...) to query "
+                f"by workflow instead."
+            )
 
-        raise WorkflowCheckpointException(f"No checkpoint found with ID {checkpoint_id}")
+        return self._document_to_checkpoint(results[0])
 
     async def list_checkpoints(self, *, workflow_name: str) -> list[WorkflowCheckpoint]:
         """List checkpoint objects for a given workflow name.
@@ -256,8 +272,7 @@ async def delete(self, checkpoint_id: CheckpointID) -> bool:
         """
         await self._ensure_container_proxy()
 
-        # We need to find the document first to get its partition key
-        query = "SELECT c.id, c.workflow_name FROM c WHERE c.id = @checkpoint_id"
+        query = "SELECT c.id, c.workflow_name FROM c WHERE c.checkpoint_id = @checkpoint_id"
         parameters: list[dict[str, object]] = [
             {"name": "@checkpoint_id", "value": checkpoint_id},
         ]
@@ -270,7 +285,7 @@ async def delete(self, checkpoint_id: CheckpointID) -> bool:
         async for item in items:
             try:
                 await self._container_proxy.delete_item(  # type: ignore[union-attr]
-                    item=checkpoint_id,
+                    item=item["id"],
                     partition_key=item["workflow_name"],
                 )
                 logger.info("Deleted checkpoint %s from Cosmos DB", checkpoint_id)
@@ -390,10 +405,19 @@ def _document_to_checkpoint(document: dict[str, Any]) -> WorkflowCheckpoint:
         Strips Cosmos DB system properties (``_rid``, ``_self``, ``_etag``,
         ``_attachments``, ``_ts``) before decoding.
         """
-        # Remove Cosmos DB system properties and the 'id' field
+        # Remove Cosmos DB system properties and the composite 'id' field
         # (checkpoints use 'checkpoint_id', not 'id')
         cosmos_keys = {"id", "_rid", "_self", "_etag", "_attachments", "_ts"}
         cleaned = {k: v for k, v in document.items() if k not in cosmos_keys}
 
         decoded = decode_checkpoint_value(cleaned)
         return WorkflowCheckpoint.from_dict(decoded)
+
+    @staticmethod
+    def _make_document_id(workflow_name: str, checkpoint_id: str) -> str:
+        """Create a composite Cosmos DB document ID.
+
+        Combines ``workflow_name`` and ``checkpoint_id`` to ensure global
+        uniqueness across partitions.
+        """
+        return f"{workflow_name}_{checkpoint_id}"
diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py
index 7734c3060a..4249214a45 100644
--- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py
+++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py
@@ -181,7 +181,7 @@ async def test_save_upserts_document(mock_container: MagicMock) -> None:
     assert result == checkpoint.checkpoint_id
     mock_container.upsert_item.assert_awaited_once()
     document = mock_container.upsert_item.await_args.kwargs["body"]
-    assert document["id"] == checkpoint.checkpoint_id
+    assert document["id"] == f"test-workflow_{checkpoint.checkpoint_id}"
     assert document["workflow_name"] == "test-workflow"
     assert document["graph_signature_hash"] == "abc123"
     assert document["state"]["counter"] == 42
@@ -233,6 +233,20 @@ async def test_load_queries_without_partition_key(mock_container: MagicMock) ->
     assert "partition_key" not in kwargs
 
 
+async def test_load_multiple_workflows_same_checkpoint_id_raises(mock_container: MagicMock) -> None:
+    cp1 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-a")
+    cp2 = _make_checkpoint(checkpoint_id="shared-id", workflow_name="workflow-b")
+    mock_container.query_items.return_value = _to_async_iter([
+        _checkpoint_to_cosmos_document(cp1),
+        _checkpoint_to_cosmos_document(cp2),
+    ])
+
+    storage = CosmosCheckpointStorage(container_client=mock_container)
+
+    with pytest.raises(WorkflowCheckpointException, match="Multiple checkpoints found"):
+        await storage.load("shared-id")
+
+
 # --- Tests for list_checkpoints ---
 
 
@@ -276,7 +290,7 @@ async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) -
 
 async def test_delete_existing_returns_true(mock_container: MagicMock) -> None:
     mock_container.query_items.return_value = _to_async_iter([
-        {"id": "cp-del", "workflow_name": "test-workflow"},
+        {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"},
     ])
 
     storage = CosmosCheckpointStorage(container_client=mock_container)
@@ -284,7 +298,7 @@ async def test_delete_existing_returns_true(mock_container: MagicMock) -> None:
 
     assert result is True
     mock_container.delete_item.assert_awaited_once_with(
-        item="cp-del",
+        item="test-workflow_cp-del",
         partition_key="test-workflow",
     )
 
@@ -301,7 +315,7 @@ async def test_delete_nonexistent_returns_false(mock_container: MagicMock) -> No
 
 async def test_delete_cosmos_not_found_returns_false(mock_container: MagicMock) -> None:
     mock_container.query_items.return_value = _to_async_iter([
-        {"id": "cp-del", "workflow_name": "test-workflow"},
+        {"id": "test-workflow_cp-del", "workflow_name": "test-workflow"},
     ])
     mock_container.delete_item = AsyncMock(side_effect=CosmosResourceNotFoundError)
 
diff --git a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
index cc8872f110..4726742ffc 100644
--- a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
+++ b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py
@@ -1,8 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 # ruff: noqa: T201
 
-"""
-Sample: Workflow Checkpointing with Cosmos DB NoSQL
+"""Sample: Workflow Checkpointing with Cosmos DB NoSQL.
 
 Purpose:
 This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint
@@ -61,6 +60,7 @@ class StartExecutor(Executor):
 
     @handler
     async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
+        """Start the workflow with numbers up to the given limit."""
         print(f"StartExecutor: Starting computation up to {upper_limit}")
         await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
 
@@ -69,6 +69,7 @@ class WorkerExecutor(Executor):
     """Processes numbers and manages executor state for checkpointing."""
 
     def __init__(self, id: str) -> None:
+        """Initialize the worker executor."""
         super().__init__(id=id)
         self._results: dict[int, list[tuple[int, int]]] = {}
 
@@ -78,6 +79,7 @@ async def compute(
         task: ComputeTask,
         ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
     ) -> None:
+        """Process the next number, computing its factor pairs."""
         next_number = task.remaining_numbers.pop(0)
         print(f"WorkerExecutor: Processing {next_number}")
 
@@ -116,84 +118,83 @@ async def main() -> None:
         return
 
     # Authentication: supports both managed identity/RBAC and key-based auth.
-    #
-    # Option 1 — Managed identity / RBAC (recommended for production):
-    #     from azure.identity.aio import DefaultAzureCredential
-    #     credential = DefaultAzureCredential()
-    #
-    # Option 2 — Account key:
-    #     credential = cosmos_key (or set AZURE_COSMOS_KEY env var)
-    #
-    # This sample uses key-based auth when AZURE_COSMOS_KEY is set,
-    # otherwise falls back to DefaultAzureCredential.
-    credential: Any
+    # When AZURE_COSMOS_KEY is set, key-based auth is used.
+    # Otherwise, falls back to DefaultAzureCredential (properly closed via async with).
     if cosmos_key:
-        credential = cosmos_key
+        async with CosmosCheckpointStorage(
+            endpoint=cosmos_endpoint,
+            credential=cosmos_key,
+            database_name=cosmos_database_name,
+            container_name=cosmos_container_name,
+        ) as checkpoint_storage:
+            await _run_workflow(checkpoint_storage)
     else:
         from azure.identity.aio import DefaultAzureCredential
 
-        credential = DefaultAzureCredential()
-
-    async with CosmosCheckpointStorage(
-        endpoint=cosmos_endpoint,
-        credential=credential,
-        database_name=cosmos_database_name,
-        container_name=cosmos_container_name,
-    ) as checkpoint_storage:
-        # Build workflow with Cosmos DB checkpointing
-        start = StartExecutor(id="start")
-        worker = WorkerExecutor(id="worker")
-        workflow_builder = (
-            WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
-            .add_edge(start, worker)
-            .add_edge(worker, worker)
-        )
-
-        # --- First run: execute the workflow ---
-        print("\n=== First Run ===\n")
-        workflow = workflow_builder.build()
-
-        output = None
-        async for event in workflow.run(message=8, stream=True):
-            if event.type == "output":
-                output = event.data
-
-        print(f"Factor pairs computed: {output}")
-
-        # List checkpoints saved in Cosmos DB
-        checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
-            workflow_name=workflow.name,
-        )
-        print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
-        for cid in checkpoint_ids:
-            print(f"  - {cid}")
-
-        # Get the latest checkpoint
-        latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
-            workflow_name=workflow.name,
-        )
-
-        if latest is None:
-            print("No checkpoint found to resume from.")
-            return
+        async with DefaultAzureCredential() as credential, CosmosCheckpointStorage(
+            endpoint=cosmos_endpoint,
+            credential=credential,
+            database_name=cosmos_database_name,
+            container_name=cosmos_container_name,
+        ) as checkpoint_storage:
+            await _run_workflow(checkpoint_storage)
+
+
+async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None:
+    """Build and run the workflow with Cosmos DB checkpointing."""
+    start = StartExecutor(id="start")
+    worker = WorkerExecutor(id="worker")
+    workflow_builder = (
+        WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
+        .add_edge(start, worker)
+        .add_edge(worker, worker)
+    )
+
+    # --- First run: execute the workflow ---
+    print("\n=== First Run ===\n")
+    workflow = workflow_builder.build()
+
+    output = None
+    async for event in workflow.run(message=8, stream=True):
+        if event.type == "output":
+            output = event.data
+
+    print(f"Factor pairs computed: {output}")
+
+    # List checkpoints saved in Cosmos DB
+    checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
+        workflow_name=workflow.name,
+    )
+    print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
+    for cid in checkpoint_ids:
+        print(f"  - {cid}")
+
+    # Get the latest checkpoint
+    latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
+        workflow_name=workflow.name,
+    )
+
+    if latest is None:
+        print("No checkpoint found to resume from.")
+        return
 
-        print(f"\nLatest checkpoint: {latest.checkpoint_id}")
-        print(f"  iteration_count: {latest.iteration_count}")
-        print(f"  timestamp: {latest.timestamp}")
+    print(f"\nLatest checkpoint: {latest.checkpoint_id}")
+    print(f"  iteration_count: {latest.iteration_count}")
+    print(f"  timestamp: {latest.timestamp}")
 
-        # --- Second run: resume from the latest checkpoint ---
-        print("\n=== Resuming from Checkpoint ===\n")
-        workflow2 = workflow_builder.build()
+    # --- Second run: resume from the latest checkpoint ---
+    print("\n=== Resuming from Checkpoint ===\n")
+    workflow2 = workflow_builder.build()
 
-        output2 = None
-        async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True):
-            if event.type == "output":
-                output2 = event.data
+    output2 = None
+    async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True):
+        if event.type == "output":
+            output2 = event.data
 
-        if output2:
-            print(f"Resumed workflow produced: {output2}")
-        else:
-            print("Resumed workflow completed (no remaining work — already finished).")
+    if output2:
+        print(f"Resumed workflow produced: {output2}")
+    else:
+        print("Resumed workflow completed (no remaining work — already finished).")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py
index 6d20a88ac5..7c19d114a0 100644
--- a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py
+++ b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py
@@ -1,8 +1,7 @@
 # Copyright (c) Microsoft. All rights reserved.
 # ruff: noqa: T201
 
-"""
-Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry
+"""Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry.
 
 Purpose:
 This sample demonstrates how to use CosmosCheckpointStorage with agents built
@@ -64,35 +63,33 @@ async def main() -> None:
         )
         return
 
-    # Authentication: use key if available, otherwise fall back to Azure credential (RBAC)
-    credential: Any
-    if cosmos_key:
-        credential = cosmos_key
-    else:
-        credential = AzureCliCredential()
-
-    async with CosmosCheckpointStorage(
-        endpoint=cosmos_endpoint,
-        credential=credential,
-        database_name=cosmos_database_name,
-        container_name=cosmos_container_name,
-    ) as checkpoint_storage:
-        # Create Azure AI Foundry agents
-        client = AzureOpenAIResponsesClient(
-            project_endpoint=project_endpoint,
-            deployment_name=deployment_name,
-            credential=AzureCliCredential(),
-        )
-
-        assistant = client.as_agent(
-            name="assistant",
-            instructions="You are a helpful assistant. Keep responses brief.",
-        )
-
-        reviewer = client.as_agent(
-            name="reviewer",
-            instructions="You are a reviewer.
Provide a one-sentence summary of the assistant's response.", - ) + # Use a single AzureCliCredential for both Cosmos and Foundry, + # properly closed via async context manager. + async with AzureCliCredential() as azure_credential: + cosmos_credential: Any = cosmos_key if cosmos_key else azure_credential + + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + # Create Azure AI Foundry agents + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=azure_credential, + ) + + assistant = client.as_agent( + name="assistant", + instructions="You are a helpful assistant. Keep responses brief.", + ) + + reviewer = client.as_agent( + name="reviewer", + instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.", + ) # Build a sequential workflow and wrap it as an agent workflow = SequentialBuilder(participants=[assistant, reviewer]).build() From 8338f192d0ec372e0db96730acca0ba0f1bfbcb9 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Thu, 26 Mar 2026 15:30:57 -0700 Subject: [PATCH 4/8] Fixing builds --- .../agent_framework_azure_cosmos/_checkpoint_storage.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 3bee671e33..08db5d51b9 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -11,12 +11,15 @@ from agent_framework._settings import SecretString, load_settings from agent_framework._workflows._checkpoint import CheckpointID, WorkflowCheckpoint from agent_framework._workflows._checkpoint_encoding import 
decode_checkpoint_value, encode_checkpoint_value -from agent_framework.azure._entra_id_authentication import AzureCredentialTypes from agent_framework.exceptions import WorkflowCheckpointException +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential from azure.cosmos import PartitionKey from azure.cosmos.aio import ContainerProxy, CosmosClient from azure.cosmos.exceptions import CosmosResourceNotFoundError +AzureCredentialTypes = TokenCredential | AsyncTokenCredential + logger = logging.getLogger(__name__) From b95bd167604bb8c1e0dd535695dad758d06a4ba9 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Thu, 26 Mar 2026 17:49:53 -0700 Subject: [PATCH 5/8] Adding sample for history provider and checkpoint storage --- python/packages/azure-cosmos/README.md | 9 +- .../azure-cosmos/samples/.env.template | 6 + .../packages/azure-cosmos/samples/README.md | 46 +++- .../samples/checkpoint_storage/__init__.py | 0 .../cosmos_checkpoint_foundry.py | 142 +++++++++++++ .../cosmos_checkpoint_workflow.py | 201 ++++++++++++++++++ .../samples/cosmos_e2e_foundry.py | 162 ++++++++++++++ .../samples/history_provider/__init__.py | 0 .../history_provider/cosmos_history_basic.py | 98 +++++++++ ...cosmos_history_conversation_persistence.py | 175 +++++++++++++++ .../cosmos_history_messages.py | 158 ++++++++++++++ .../cosmos_history_sessions.py | 198 +++++++++++++++++ 12 files changed, 1184 insertions(+), 11 deletions(-) create mode 100644 python/packages/azure-cosmos/samples/.env.template create mode 100644 python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py create mode 100644 python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py create mode 100644 python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py create mode 100644 python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py create mode 100644 
python/packages/azure-cosmos/samples/history_provider/__init__.py create mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py create mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py create mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py create mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md index 3351f3c7c1..18bbd8dd17 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -35,7 +35,7 @@ Container naming behavior: - Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`) - `session_id` is used as the Cosmos partition key for reads/writes -See `samples/cosmos_history_provider.py` for a runnable package-local example. +See `samples/history_provider/cosmos_history_basic.py` for a runnable package-local example. ## Cosmos DB Workflow Checkpoint Storage @@ -121,6 +121,7 @@ portal with this partition key configuration. | `AZURE_COSMOS_CONTAINER_NAME` | Container name | | `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) | -See `samples/cosmos_workflow_checkpointing.py` for a standalone example, or -`samples/cosmos_workflow_checkpointing_foundry.py` for an end-to-end example -with Azure AI Foundry agents. +See `samples/checkpoint_storage/cosmos_checkpoint_workflow.py` for a standalone example, +`samples/checkpoint_storage/cosmos_checkpoint_foundry.py` for an end-to-end example +with Azure AI Foundry agents, or `samples/cosmos_e2e_foundry.py` for both +history and checkpointing together. 
diff --git a/python/packages/azure-cosmos/samples/.env.template b/python/packages/azure-cosmos/samples/.env.template
new file mode 100644
index 0000000000..c200634623
--- /dev/null
+++ b/python/packages/azure-cosmos/samples/.env.template
@@ -0,0 +1,6 @@
+AZURE_AI_PROJECT_ENDPOINT=
+AZURE_AI_MODEL_DEPLOYMENT_NAME=
+AZURE_COSMOS_ENDPOINT=
+AZURE_COSMOS_KEY=
+AZURE_COSMOS_DATABASE_NAME=
+AZURE_COSMOS_CONTAINER_NAME=
diff --git a/python/packages/azure-cosmos/samples/README.md b/python/packages/azure-cosmos/samples/README.md
index 9c767a56a1..4c448f159f 100644
--- a/python/packages/azure-cosmos/samples/README.md
+++ b/python/packages/azure-cosmos/samples/README.md
@@ -2,23 +2,55 @@
 This folder contains samples for `agent-framework-azure-cosmos`.
 
+## History Provider Samples
+
+Demonstrate conversation persistence using `CosmosHistoryProvider`.
+
+| File | Description |
+| --- | --- |
+| [`history_provider/cosmos_history_basic.py`](history_provider/cosmos_history_basic.py) | Basic multi-turn conversation using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient`, provider-configured container name, and `session_id` partitioning. |
+| [`history_provider/cosmos_history_conversation_persistence.py`](history_provider/cosmos_history_conversation_persistence.py) | Persist and resume conversations across application restarts — serialize session state, create new provider/agent instances, and continue from Cosmos DB history. |
+| [`history_provider/cosmos_history_messages.py`](history_provider/cosmos_history_messages.py) | Direct message history operations — retrieve stored messages as a transcript, clear session history, and verify data deletion. |
+| [`history_provider/cosmos_history_sessions.py`](history_provider/cosmos_history_sessions.py) | Multi-session and multi-tenant management — per-tenant session isolation, `list_sessions()` to enumerate, switch between sessions, and resume specific conversations.
| + +## Checkpoint Storage Samples + +Demonstrate workflow pause/resume using `CosmosCheckpointStorage`. + | File | Description | | --- | --- | -| [`cosmos_history_provider.py`](cosmos_history_provider.py) | Demonstrates an Agent using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient` (project endpoint), provider-configured container name, and `session_id` partitioning. | -| [`cosmos_workflow_checkpointing.py`](cosmos_workflow_checkpointing.py) | Workflow checkpoint storage with Cosmos DB — pause and resume workflows across restarts using `CosmosCheckpointStorage`, with support for key-based and managed identity auth. | -| [`cosmos_workflow_checkpointing_foundry.py`](cosmos_workflow_checkpointing_foundry.py) | End-to-end Azure AI Foundry + Cosmos DB checkpointing — multi-agent workflow using `AzureOpenAIResponsesClient` with `CosmosCheckpointStorage` for durable pause/resume. | +| [`checkpoint_storage/cosmos_checkpoint_workflow.py`](checkpoint_storage/cosmos_checkpoint_workflow.py) | Workflow checkpoint storage with Cosmos DB — pause and resume workflows across restarts using `CosmosCheckpointStorage`, with support for key-based and managed identity auth. | +| [`checkpoint_storage/cosmos_checkpoint_foundry.py`](checkpoint_storage/cosmos_checkpoint_foundry.py) | End-to-end Azure AI Foundry + Cosmos DB checkpointing — multi-agent workflow using `AzureOpenAIResponsesClient` with `CosmosCheckpointStorage` for durable pause/resume. | + +## Combined Sample + +| File | Description | +| --- | --- | +| [`cosmos_e2e_foundry.py`](cosmos_e2e_foundry.py) | Both `CosmosHistoryProvider` and `CosmosCheckpointStorage` in a single Azure AI Foundry agent app — the recommended production pattern for fully durable agent workflows. 
| ## Prerequisites - `AZURE_COSMOS_ENDPOINT` - `AZURE_COSMOS_DATABASE_NAME` -- `AZURE_COSMOS_CONTAINER_NAME` - `AZURE_COSMOS_KEY` (or equivalent credential flow) +For Foundry samples, also set: +- `AZURE_AI_PROJECT_ENDPOINT` +- `AZURE_AI_MODEL_DEPLOYMENT_NAME` + ## Run ```bash -uv run --directory packages/azure-cosmos python samples/cosmos_history_provider.py -uv run --directory packages/azure-cosmos python samples/cosmos_workflow_checkpointing.py -uv run --directory packages/azure-cosmos python samples/cosmos_workflow_checkpointing_foundry.py +# History provider samples +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_basic.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_conversation_persistence.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_messages.py +uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_sessions.py + +# Checkpoint storage samples +uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_workflow.py +uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_foundry.py + +# Combined sample +uv run --directory packages/azure-cosmos python samples/cosmos_e2e_foundry.py ``` diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py b/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py new file mode 100644 index 0000000000..7c19d114a0 --- /dev/null +++ b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py @@ -0,0 +1,142 @@ +# Copyright (c) Microsoft. All rights reserved. 
+# ruff: noqa: T201
+
+"""Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry.
+
+Purpose:
+This sample demonstrates how to use CosmosCheckpointStorage with agents built
+on Azure AI Foundry (via AzureOpenAIResponsesClient). It shows a multi-agent
+workflow where checkpoint state is persisted to Cosmos DB, enabling durable
+pause-and-resume across process restarts.
+
+What you learn:
+- How to wire CosmosCheckpointStorage with AzureOpenAIResponsesClient agents
+- How to combine session history with workflow checkpointing
+- How to resume a workflow-as-agent from a Cosmos DB checkpoint
+
+Key concepts:
+- AgentSession: Maintains conversation history across agent invocations
+- CosmosCheckpointStorage: Persists workflow execution state in Cosmos DB
+- These are complementary: sessions track conversation, checkpoints track workflow state
+
+Environment variables:
+    AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint
+    AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name
+    AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
+    AZURE_COSMOS_DATABASE_NAME - Database name
+    AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
+Optional:
+    AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
+"""
+
+import asyncio
+import os
+from typing import Any
+
+from agent_framework.azure import AzureOpenAIResponsesClient
+from agent_framework.orchestrations import SequentialBuilder
+from azure.identity.aio import AzureCliCredential
+from dotenv import load_dotenv
+
+from agent_framework_azure_cosmos import CosmosCheckpointStorage
+
+load_dotenv()
+
+
+async def main() -> None:
+    """Run the Azure AI Foundry + Cosmos DB checkpointing sample."""
+    project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
+    deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")
+    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
+    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
+    cosmos_container_name =
os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not project_endpoint or not deployment_name: + print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") + return + + if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: + print( + "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " + "and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # Use a single AzureCliCredential for both Cosmos and Foundry, + # properly closed via async context manager. + async with AzureCliCredential() as azure_credential: + cosmos_credential: Any = cosmos_key if cosmos_key else azure_credential + + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + # Create Azure AI Foundry agents + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=azure_credential, + ) + + assistant = client.as_agent( + name="assistant", + instructions="You are a helpful assistant. Keep responses brief.", + ) + + reviewer = client.as_agent( + name="reviewer", + instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.", + ) + + # Build a sequential workflow and wrap it as an agent + workflow = SequentialBuilder(participants=[assistant, reviewer]).build() + agent = workflow.as_agent(name="FoundryCheckpointedAgent") + + # --- First run: execute with Cosmos DB checkpointing --- + print("=== First Run ===\n") + + session = agent.create_session() + query = "What are the benefits of renewable energy?" 
+ print(f"User: {query}") + + response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage) + + for msg in response.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show checkpoints persisted in Cosmos DB + checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}") + for i, cp in enumerate(checkpoints[:5], 1): + print(f" {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})") + + # --- Second run: continue conversation with checkpoint history --- + print("\n=== Second Run (continuing conversation) ===\n") + + query2 = "Can you elaborate on the economic benefits?" + print(f"User: {query2}") + + response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) + + for msg in response2.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show total checkpoints + all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}") + + # Get latest checkpoint + latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) + if latest: + print(f"Latest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py new file mode 100644 index 0000000000..4726742ffc --- /dev/null +++ b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +"""Sample: Workflow Checkpointing with Cosmos DB NoSQL. 
+ +Purpose: +This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint +storage backend for workflows, enabling durable pause-and-resume across +process restarts. + +What you learn: +- How to configure CosmosCheckpointStorage for workflow checkpointing +- How to run a workflow that automatically persists checkpoints to Cosmos DB +- How to resume a workflow from a Cosmos DB checkpoint +- How to list and inspect available checkpoints + +Prerequisites: +- An Azure Cosmos DB account (or local emulator) +- Environment variables set (see below) + +Environment variables: + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name + AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +import sys +from dataclasses import dataclass +from typing import Any + +from agent_framework import ( + Executor, + WorkflowBuilder, + WorkflowCheckpoint, + WorkflowContext, + handler, +) + +if sys.version_info >= (3, 12): + from typing import override # type: ignore # pragma: no cover +else: + from typing_extensions import override # type: ignore[import] # pragma: no cover + +from agent_framework_azure_cosmos import CosmosCheckpointStorage + + +@dataclass +class ComputeTask: + """Task containing the list of numbers remaining to be processed.""" + + remaining_numbers: list[int] + + +class StartExecutor(Executor): + """Initiates the workflow by providing the upper limit.""" + + @handler + async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None: + """Start the workflow with numbers up to the given limit.""" + print(f"StartExecutor: Starting computation up to {upper_limit}") + await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1)))) + + +class WorkerExecutor(Executor): + """Processes numbers and manages executor state for checkpointing.""" + + def __init__(self, 
id: str) -> None: + """Initialize the worker executor.""" + super().__init__(id=id) + self._results: dict[int, list[tuple[int, int]]] = {} + + @handler + async def compute( + self, + task: ComputeTask, + ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]], + ) -> None: + """Process the next number, computing its factor pairs.""" + next_number = task.remaining_numbers.pop(0) + print(f"WorkerExecutor: Processing {next_number}") + + pairs: list[tuple[int, int]] = [] + for i in range(1, next_number): + if next_number % i == 0: + pairs.append((i, next_number // i)) + self._results[next_number] = pairs + + if not task.remaining_numbers: + await ctx.yield_output(self._results) + else: + await ctx.send_message(task) + + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + return {"results": self._results} + + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + self._results = state.get("results", {}) + + +async def main() -> None: + """Run the workflow checkpointing sample with Cosmos DB.""" + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: + print( + "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " + "and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # Authentication: supports both managed identity/RBAC and key-based auth. + # When AZURE_COSMOS_KEY is set, key-based auth is used. + # Otherwise, falls back to DefaultAzureCredential (properly closed via async with). 
+ if cosmos_key: + async with CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_key, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + await _run_workflow(checkpoint_storage) + else: + from azure.identity.aio import DefaultAzureCredential + + async with DefaultAzureCredential() as credential, CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=credential, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + ) as checkpoint_storage: + await _run_workflow(checkpoint_storage) + + +async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None: + """Build and run the workflow with Cosmos DB checkpointing.""" + start = StartExecutor(id="start") + worker = WorkerExecutor(id="worker") + workflow_builder = ( + WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage) + .add_edge(start, worker) + .add_edge(worker, worker) + ) + + # --- First run: execute the workflow --- + print("\n=== First Run ===\n") + workflow = workflow_builder.build() + + output = None + async for event in workflow.run(message=8, stream=True): + if event.type == "output": + output = event.data + + print(f"Factor pairs computed: {output}") + + # List checkpoints saved in Cosmos DB + checkpoint_ids = await checkpoint_storage.list_checkpoint_ids( + workflow_name=workflow.name, + ) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}") + for cid in checkpoint_ids: + print(f" - {cid}") + + # Get the latest checkpoint + latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest( + workflow_name=workflow.name, + ) + + if latest is None: + print("No checkpoint found to resume from.") + return + + print(f"\nLatest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") + + # --- Second run: resume from the latest checkpoint --- + print("\n=== Resuming from 
Checkpoint ===\n") + workflow2 = workflow_builder.build() + + output2 = None + async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True): + if event.type == "output": + output2 = event.data + + if output2: + print(f"Resumed workflow produced: {output2}") + else: + print("Resumed workflow completed (no remaining work — already finished).") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py b/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py new file mode 100644 index 0000000000..1e8a20c388 --- /dev/null +++ b/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py @@ -0,0 +1,162 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +"""Sample: Combined History Provider + Checkpoint Storage with Azure AI Foundry. + +Purpose: +This sample demonstrates using both CosmosHistoryProvider (conversation memory) +and CosmosCheckpointStorage (workflow pause/resume) together in a single Azure AI +Foundry agent application. This is the recommended production pattern for customers +who need both durable conversations and durable workflow execution. 
+ +What you learn: +- How to wire both CosmosHistoryProvider and CosmosCheckpointStorage in one app +- How conversation history and workflow checkpoints serve complementary roles +- How to resume both conversation context and workflow execution state + +Key concepts: +- CosmosHistoryProvider: Persists conversation messages across sessions +- CosmosCheckpointStorage: Persists workflow execution state for pause/resume +- Together they enable fully durable agent workflows + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint + AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name + AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint + AZURE_COSMOS_DATABASE_NAME - Database name +Optional: + AZURE_COSMOS_KEY - Account key (if not using Azure credentials) +""" + +import asyncio +import os +from typing import Any + +from agent_framework import WorkflowAgent, WorkflowCheckpoint +from agent_framework.azure import AzureOpenAIResponsesClient +from agent_framework.orchestrations import SequentialBuilder +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosCheckpointStorage, CosmosHistoryProvider + +load_dotenv() + + +async def main() -> None: + """Run the combined history + checkpoint sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if not project_endpoint or not deployment_name: + print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") + return + + if not cosmos_endpoint or not cosmos_database_name: + print("Please set AZURE_COSMOS_ENDPOINT and AZURE_COSMOS_DATABASE_NAME.") + return + + async with AzureCliCredential() as azure_credential: + cosmos_credential: Any = cosmos_key if 
cosmos_key else azure_credential + + # CosmosHistoryProvider: stores conversation messages + # CosmosCheckpointStorage: stores workflow execution state + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name="conversation-history", + credential=cosmos_credential, + ) as history_provider, + CosmosCheckpointStorage( + endpoint=cosmos_endpoint, + credential=cosmos_credential, + database_name=cosmos_database_name, + container_name="workflow-checkpoints", + ) as checkpoint_storage, + ): + # Create Azure AI Foundry agents + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=azure_credential, + ) + + assistant = client.as_agent( + name="assistant", + instructions="You are a helpful assistant. Keep responses brief.", + ) + + reviewer = client.as_agent( + name="reviewer", + instructions=( + "You are a reviewer. Provide a one-sentence " + "summary of the assistant's response." + ), + ) + + # Build a workflow with both history and checkpointing. + # Attach the history provider to the WorkflowAgent (outer agent) + # so conversation messages are persisted at the agent level. 
+ workflow = SequentialBuilder( + participants=[assistant, reviewer], + ).build() + agent = WorkflowAgent( + workflow, + name="DurableAgent", + context_providers=[history_provider], + ) + + # --- First run --- + print("=== First Run ===\n") + session = agent.create_session() + + response = await agent.run( + "What are three benefits of cloud computing?", + session=session, + checkpoint_storage=checkpoint_storage, + ) + + for msg in response.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show what's persisted + checkpoints = await checkpoint_storage.list_checkpoints( + workflow_name=workflow.name, + ) + history = await history_provider.get_messages(session.session_id) + + print(f"\nConversation messages in Cosmos DB: {len(history)}") + print(f"Workflow checkpoints in Cosmos DB: {len(checkpoints)}") + + # --- Second run: conversation context is loaded from history --- + print("\n=== Second Run (with conversation context) ===\n") + + response2 = await agent.run( + "Can you elaborate on the first benefit?", + session=session, + checkpoint_storage=checkpoint_storage, + ) + + for msg in response2.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") + + # Show updated state + latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest( + workflow_name=workflow.name, + ) + history2 = await history_provider.get_messages(session.session_id) + + print(f"\nConversation messages after 2 runs: {len(history2)}") + if latest: + print(f"Latest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/packages/azure-cosmos/samples/history_provider/__init__.py b/python/packages/azure-cosmos/samples/history_provider/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py 
b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py new file mode 100644 index 0000000000..5dc5c54b65 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import os + +from agent_framework import Agent +from agent_framework.azure import CosmosHistoryProvider +from agent_framework.foundry import FoundryChatClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates CosmosHistoryProvider as an agent history provider. + +Key components: +- FoundryChatClient configured with an Azure AI project endpoint +- CosmosHistoryProvider configured for Cosmos DB-backed message history +- Provider-configured container name with session_id as partition key + +Environment variables: + FOUNDRY_PROJECT_ENDPOINT + FOUNDRY_MODEL + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the Cosmos history provider sample with an Agent.""" + project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT") + model = os.getenv("FOUNDRY_MODEL") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not model + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # 1. 
Create an Azure credential and a CosmosHistoryProvider for agent context + async with ( + AzureCliCredential() as credential, + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + # 2. Create an agent that uses Cosmos for persisted conversation history. + Agent( + client=FoundryChatClient( + project_endpoint=project_endpoint, + model=model, + credential=credential, + ), + name="CosmosHistoryAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 3. Create a session (session_id is used as the partition key). + session = agent.create_session() + + # 4. Run a multi-turn conversation; history is persisted by CosmosHistoryProvider. + response1 = await agent.run("My name is Ada and I enjoy distributed systems.", session=session) + print(f"Assistant: {response1.text}") + + response2 = await agent.run("What do you remember about me?", session=session) + print(f"Assistant: {response2.text}") + print(f"Container: {history_provider.container_name}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +Assistant: Nice to meet you, Ada! Distributed systems are a fascinating area. +Assistant: You told me your name is Ada and that you enjoy distributed systems. +Container: +""" diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py new file mode 100644 index 0000000000..f8e97e8d17 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. 
+# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework import AgentSession +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates persisting and resuming conversations across application +restarts using CosmosHistoryProvider as the persistent backend. + +Key components: +- Phase 1: Run a conversation and serialize the session with session.to_dict() +- Phase 2: Simulate an app restart — create new provider and agent instances, + restore the session with AgentSession.from_dict(), and continue the conversation +- Cosmos DB reloads the full message history, so the agent remembers everything + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the conversation persistence sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # ── Phase 1: Initial conversation ── + + print("=== Phase 1: Initial conversation ===\n") + + # 1. Create credential, client, history provider, and agent. 
+ async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 2. Create a session and have a multi-turn conversation. + session = agent.create_session() + + response1 = await agent.run( + "My name is Ada. I'm building a distributed database in Rust.", session=session + ) + print("User: My name is Ada. I'm building a distributed database in Rust.") + print(f"Assistant: {response1.text}\n") + + response2 = await agent.run("The hardest part is the consensus algorithm.", session=session) + print("User: The hardest part is the consensus algorithm.") + print(f"Assistant: {response2.text}\n") + + # 3. Serialize the session state — this is what you'd persist to a database or file. + serialized_session = session.to_dict() + print(f"Session serialized. Session ID: {session.session_id}") + + # ── Phase 2: Simulate app restart ── + + print("\n=== Phase 2: Resuming after 'restart' ===\n") + + # 4. Create entirely new provider and agent instances (simulating a fresh process). 
+ async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 5. Restore the session from the serialized state. + restored_session = AgentSession.from_dict(serialized_session) + print(f"Session restored. Session ID: {restored_session.session_id}\n") + + # 6. Continue the conversation — history is reloaded from Cosmos DB. + response3 = await agent.run("What was I working on and what was the challenge?", session=restored_session) + print("User: What was I working on and what was the challenge?") + print(f"Assistant: {response3.text}\n") + + # 7. Verify messages are in Cosmos by reading them directly. + messages = await history_provider.get_messages(restored_session.session_id) + print(f"Messages stored in Cosmos DB: {len(messages)}") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:80]}...") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Phase 1: Initial conversation === + +User: My name is Ada. I'm building a distributed database in Rust. +Assistant: That sounds like a great project, Ada! Rust is an excellent choice for ... + +User: The hardest part is the consensus algorithm. +Assistant: Consensus algorithms can be tricky! Are you looking at Raft, Paxos, or ... + +Session serialized. Session ID: + +=== Phase 2: Resuming after 'restart' === + +Session restored. Session ID: + +User: What was I working on and what was the challenge? 
+Assistant: You told me you're building a distributed database in Rust and that the hardest +part is the consensus algorithm. + +Messages stored in Cosmos DB: 6 + 1. [user] My name is Ada. I'm building a distributed database in Rust.... + 2. [assistant] That sounds like a great project, Ada! Rust is an excellent ch... + 3. [user] The hardest part is the consensus algorithm.... + 4. [assistant] Consensus algorithms can be tricky! Are you looking at Raft, Pa... + 5. [user] What was I working on and what was the challenge?... + 6. [assistant] You told me you're building a distributed database in Rust and ... +""" diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py new file mode 100644 index 0000000000..c97504fd40 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py @@ -0,0 +1,158 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates direct message history operations using +CosmosHistoryProvider — retrieving, displaying, and clearing stored messages. 
+ +Key components: +- get_messages(session_id): Retrieve all stored messages as a chat transcript +- clear(session_id): Delete all messages for a session (e.g., GDPR compliance) +- Verifying that history is empty after clearing +- Running a new conversation in the same session after clearing + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the messages history sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="HistoryAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + session = agent.create_session() + session_id = session.session_id + + # 1. Have a multi-turn conversation. 
+ print("=== Building a conversation ===\n") + + queries = [ + "Hi! My favorite programming language is Python.", + "I also enjoy hiking in the mountains on weekends.", + "What do you know about me so far?", + ] + for query in queries: + response = await agent.run(query, session=session) + print(f"User: {query}") + print(f"Assistant: {response.text}\n") + + # 2. Retrieve and display the full message history as a transcript. + print("=== Chat transcript from Cosmos DB ===\n") + + messages = await history_provider.get_messages(session_id) + print(f"Total messages stored: {len(messages)}\n") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:100]}") + + # 3. Clear the session history. + print("\n=== Clearing session history ===\n") + + await history_provider.clear(session_id) + print(f"Cleared all messages for session: {session_id}") + + # 4. Verify history is empty. + remaining = await history_provider.get_messages(session_id) + print(f"Messages after clear: {len(remaining)}") + + # 5. Start a fresh conversation in the same session — agent has no memory. + print("\n=== Fresh conversation (same session, no memory) ===\n") + + response = await agent.run("What do you know about me?", session=session) + print("User: What do you know about me?") + print(f"Assistant: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Building a conversation === + +User: Hi! My favorite programming language is Python. +Assistant: That's great! Python is a wonderful language. What do you like most about it? + +User: I also enjoy hiking in the mountains on weekends. +Assistant: Hiking sounds lovely! Do you have a favorite trail or mountain range? + +User: What do you know about me so far? +Assistant: You love Python as your favorite programming language and enjoy hiking in the mountains on weekends. + +=== Chat transcript from Cosmos DB === + +Total messages stored: 6 + + 1. [user] Hi! 
My favorite programming language is Python. + 2. [assistant] That's great! Python is a wonderful language. What do you like most about it? + 3. [user] I also enjoy hiking in the mountains on weekends. + 4. [assistant] Hiking sounds lovely! Do you have a favorite trail or mountain range? + 5. [user] What do you know about me so far? + 6. [assistant] You love Python as your favorite programming language and enjoy hiking ... + +=== Clearing session history === + +Cleared all messages for session: +Messages after clear: 0 + +=== Fresh conversation (same session, no memory) === + +User: What do you know about me? +Assistant: I don't have any information about you yet. Feel free to share anything you'd like! +""" diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py new file mode 100644 index 0000000000..6772d41825 --- /dev/null +++ b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py @@ -0,0 +1,198 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework.azure import AzureOpenAIResponsesClient +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework_azure_cosmos import CosmosHistoryProvider + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates multi-session and multi-tenant management using +CosmosHistoryProvider. Each tenant (user) gets isolated conversation sessions +stored in the same Cosmos DB container, partitioned by session_id. 
+ +Key components: +- Per-tenant session isolation using prefixed session IDs +- list_sessions(): Enumerate all stored sessions across tenants +- Switching between sessions for different users +- Resuming a specific user's session — verifying data isolation + +Environment variables: + AZURE_AI_PROJECT_ENDPOINT + AZURE_AI_MODEL_DEPLOYMENT_NAME + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the session management sample.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") + deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not deployment_name + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with AzureCliCredential() as credential: + client = AzureOpenAIResponsesClient( + project_endpoint=project_endpoint, + deployment_name=deployment_name, + credential=credential, + ) + + async with ( + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + client.as_agent( + name="MultiTenantAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 1. Tenant "alice" starts a conversation about travel. 
+ print("=== Tenant: Alice — Travel conversation ===\n") + + alice_session = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run( + "Hi! I'm planning a trip to Italy. I love Renaissance art.", session=alice_session + ) + print("Alice: I'm planning a trip to Italy. I love Renaissance art.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("Which museums should I visit in Florence?", session=alice_session) + print("Alice: Which museums should I visit in Florence?") + print(f"Assistant: {response.text}\n") + + # 2. Tenant "bob" starts a separate conversation about cooking. + print("=== Tenant: Bob — Cooking conversation ===\n") + + bob_session = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run( + "Hey! I'm learning to cook Thai food. I just made pad thai.", session=bob_session + ) + print("Bob: I'm learning to cook Thai food. I just made pad thai.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("What Thai dish should I try next?", session=bob_session) + print("Bob: What Thai dish should I try next?") + print(f"Assistant: {response.text}\n") + + # 3. List all sessions stored in Cosmos DB. + print("=== Listing all sessions ===\n") + + sessions = await history_provider.list_sessions() + print(f"Found {len(sessions)} session(s):") + for sid in sessions: + print(f" - {sid}") + + # 4. Resume Alice's session — verify she gets her travel context back. + print("\n=== Resuming Alice's session ===\n") + + alice_resumed = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run("What were we discussing?", session=alice_resumed) + print("Alice: What were we discussing?") + print(f"Assistant: {response.text}\n") + + # 5. Resume Bob's session — verify he gets his cooking context back. 
+ print("=== Resuming Bob's session ===\n") + + bob_resumed = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run("What was the last dish I mentioned?", session=bob_resumed) + print("Bob: What was the last dish I mentioned?") + print(f"Assistant: {response.text}\n") + + # 6. Show per-session message counts. + print("=== Per-session message counts ===\n") + + alice_messages = await history_provider.get_messages("tenant-alice-session-1") + bob_messages = await history_provider.get_messages("tenant-bob-session-1") + print(f"Alice's session: {len(alice_messages)} messages") + print(f"Bob's session: {len(bob_messages)} messages") + + # 7. Clean up: clear both sessions. + print("\n=== Cleaning up ===\n") + + await history_provider.clear("tenant-alice-session-1") + await history_provider.clear("tenant-bob-session-1") + print("Cleared Alice's and Bob's sessions.") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Tenant: Alice — Travel conversation === + +Alice: I'm planning a trip to Italy. I love Renaissance art. +Assistant: Italy is a dream for Renaissance art lovers! Florence, Rome, and Venice ... + +Alice: Which museums should I visit in Florence? +Assistant: In Florence, the Uffizi Gallery is a must — it has Botticelli's Birth of Venus ... + +=== Tenant: Bob — Cooking conversation === + +Bob: I'm learning to cook Thai food. I just made pad thai. +Assistant: Pad thai is a great start! How did it turn out? + +Bob: What Thai dish should I try next? +Assistant: I'd suggest trying green curry or tom yum soup — both are classic Thai dishes ... + +=== Listing all sessions === + +Found 2 session(s): + - tenant-alice-session-1 + - tenant-bob-session-1 + +=== Resuming Alice's session === + +Alice: What were we discussing? +Assistant: We were discussing your trip to Italy and your love for Renaissance art ... + +=== Resuming Bob's session === + +Bob: What was the last dish I mentioned? 
+Assistant: You mentioned pad thai — it was the dish you just made! + +=== Per-session message counts === + +Alice's session: 6 messages +Bob's session: 6 messages + +=== Cleaning up === + +Cleared Alice's and Bob's sessions. +""" From 17b83290ba429f9f6a5aa514747a26c5a3b90bf6 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Sat, 28 Mar 2026 08:36:55 -0700 Subject: [PATCH 6/8] Resolving comments --- .../cosmos_checkpoint_foundry.py | 68 +++++++++---------- .../tests/test_cosmos_checkpoint_storage.py | 27 +++++++- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py index 7c19d114a0..37fd6c986a 100644 --- a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py +++ b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py @@ -91,51 +91,51 @@ async def main() -> None: instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.", ) - # Build a sequential workflow and wrap it as an agent - workflow = SequentialBuilder(participants=[assistant, reviewer]).build() - agent = workflow.as_agent(name="FoundryCheckpointedAgent") + # Build a sequential workflow and wrap it as an agent + workflow = SequentialBuilder(participants=[assistant, reviewer]).build() + agent = workflow.as_agent(name="FoundryCheckpointedAgent") - # --- First run: execute with Cosmos DB checkpointing --- - print("=== First Run ===\n") + # --- First run: execute with Cosmos DB checkpointing --- + print("=== First Run ===\n") - session = agent.create_session() - query = "What are the benefits of renewable energy?" - print(f"User: {query}") + session = agent.create_session() + query = "What are the benefits of renewable energy?" 
+ print(f"User: {query}") - response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage) + response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage) - for msg in response.messages: - speaker = msg.author_name or msg.role - print(f"[{speaker}]: {msg.text}") + for msg in response.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") - # Show checkpoints persisted in Cosmos DB - checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) - print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}") - for i, cp in enumerate(checkpoints[:5], 1): - print(f" {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})") + # Show checkpoints persisted in Cosmos DB + checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}") + for i, cp in enumerate(checkpoints[:5], 1): + print(f" {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})") - # --- Second run: continue conversation with checkpoint history --- - print("\n=== Second Run (continuing conversation) ===\n") + # --- Second run: continue conversation with checkpoint history --- + print("\n=== Second Run (continuing conversation) ===\n") - query2 = "Can you elaborate on the economic benefits?" - print(f"User: {query2}") + query2 = "Can you elaborate on the economic benefits?" 
+ print(f"User: {query2}") - response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) + response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) - for msg in response2.messages: - speaker = msg.author_name or msg.role - print(f"[{speaker}]: {msg.text}") + for msg in response2.messages: + speaker = msg.author_name or msg.role + print(f"[{speaker}]: {msg.text}") - # Show total checkpoints - all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) - print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}") + # Show total checkpoints + all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) + print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}") - # Get latest checkpoint - latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) - if latest: - print(f"Latest checkpoint: {latest.checkpoint_id}") - print(f" iteration_count: {latest.iteration_count}") - print(f" timestamp: {latest.timestamp}") + # Get latest checkpoint + latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) + if latest: + print(f"Latest checkpoint: {latest.checkpoint_id}") + print(f" iteration_count: {latest.iteration_count}") + print(f" timestamp: {latest.timestamp}") if __name__ == "__main__": diff --git a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py index 4249214a45..5e183c3223 100644 --- a/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py +++ b/python/packages/azure-cosmos/tests/test_cosmos_checkpoint_storage.py @@ -66,7 +66,7 @@ def _checkpoint_to_cosmos_document(checkpoint: WorkflowCheckpoint) -> dict[str, """Simulate what a Cosmos DB document looks like after save.""" encoded = encode_checkpoint_value(checkpoint.to_dict()) doc: dict[str, Any] = { - "id": checkpoint.checkpoint_id, + "id": 
f"{checkpoint.workflow_name}_{checkpoint.checkpoint_id}", "workflow_name": checkpoint.workflow_name, **encoded, # Cosmos system properties @@ -285,6 +285,20 @@ async def test_list_checkpoints_empty_returns_empty(mock_container: MagicMock) - assert results == [] +async def test_list_checkpoints_skips_malformed_documents(mock_container: MagicMock) -> None: + valid_cp = _make_checkpoint(checkpoint_id="cp-valid") + mock_container.query_items.return_value = _to_async_iter([ + {"id": "bad_doc", "workflow_name": "test-workflow", "not_a_checkpoint": True}, + _checkpoint_to_cosmos_document(valid_cp), + ]) + + storage = CosmosCheckpointStorage(container_client=mock_container) + results = await storage.list_checkpoints(workflow_name="test-workflow") + + assert len(results) == 1 + assert results[0].checkpoint_id == "cp-valid" + + # --- Tests for delete --- @@ -446,6 +460,17 @@ async def test_context_manager_preserves_original_exception(mock_container: Magi raise ValueError("inner error") +async def test_context_manager_reraises_close_error(mock_container: MagicMock) -> None: + storage = CosmosCheckpointStorage(container_client=mock_container) + + with ( + patch.object(storage, "close", AsyncMock(side_effect=RuntimeError("close failed"))), + pytest.raises(RuntimeError, match="close failed"), + ): + async with storage: + pass # no inner exception — close error should propagate + + # --- Tests for save/load round-trip --- From e287d61e3ac8dbf04a251a11ef08de6c390c72e2 Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Thu, 2 Apr 2026 12:13:30 -0700 Subject: [PATCH 7/8] fixing builds --- python/packages/azure-cosmos/README.md | 9 +- .../azure-cosmos/samples/.env.template | 6 - .../packages/azure-cosmos/samples/README.md | 56 ----- .../samples/checkpoint_storage/__init__.py | 0 .../cosmos_checkpoint_workflow.py | 201 ------------------ .../samples/cosmos_e2e_foundry.py | 162 -------------- .../samples/history_provider/__init__.py | 0 .../history_provider/cosmos_history_basic.py | 
98 --------- ...cosmos_history_conversation_persistence.py | 175 --------------- .../cosmos_history_messages.py | 158 -------------- .../cosmos_history_sessions.py | 198 ----------------- .../samples/02-agents/conversations/README.md | 5 +- ...story_provider_conversation_persistence.py | 165 ++++++++++++++ .../cosmos_history_provider_messages.py | 157 ++++++++++++++ .../cosmos_history_provider_sessions.py | 197 +++++++++++++++++ .../cosmos_workflow_checkpointing_foundry.py | 142 ------------- .../cosmos_workflow_checkpointing.py | 0 .../cosmos_workflow_checkpointing_foundry.py} | 32 +-- 18 files changed, 544 insertions(+), 1217 deletions(-) delete mode 100644 python/packages/azure-cosmos/samples/.env.template delete mode 100644 python/packages/azure-cosmos/samples/README.md delete mode 100644 python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py delete mode 100644 python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py delete mode 100644 python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py delete mode 100644 python/packages/azure-cosmos/samples/history_provider/__init__.py delete mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py delete mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py delete mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py delete mode 100644 python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py create mode 100644 python/samples/02-agents/conversations/cosmos_history_provider_conversation_persistence.py create mode 100644 python/samples/02-agents/conversations/cosmos_history_provider_messages.py create mode 100644 python/samples/02-agents/conversations/cosmos_history_provider_sessions.py delete mode 100644 python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py rename python/samples/{02-agents/conversations => 
03-workflows/checkpoint}/cosmos_workflow_checkpointing.py (100%) rename python/{packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py => samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py} (86%) diff --git a/python/packages/azure-cosmos/README.md b/python/packages/azure-cosmos/README.md index 18bbd8dd17..a03c5c6f93 100644 --- a/python/packages/azure-cosmos/README.md +++ b/python/packages/azure-cosmos/README.md @@ -35,7 +35,7 @@ Container naming behavior: - Container name is configured on the provider (`container_name` or `AZURE_COSMOS_CONTAINER_NAME`) - `session_id` is used as the Cosmos partition key for reads/writes -See `samples/history_provider/cosmos_history_basic.py` for a runnable package-local example. +See `samples/02-agents/conversations/cosmos_history_provider.py` for a runnable example. ## Cosmos DB Workflow Checkpoint Storage @@ -121,7 +121,6 @@ portal with this partition key configuration. | `AZURE_COSMOS_CONTAINER_NAME` | Container name | | `AZURE_COSMOS_KEY` | Account key (optional if using Azure credentials) | -See `samples/checkpoint_storage/cosmos_checkpoint_workflow.py` for a standalone example, -`samples/checkpoint_storage/cosmos_checkpoint_foundry.py` for an end-to-end example -with Azure AI Foundry agents, or `samples/cosmos_e2e_foundry.py` for both -history and checkpointing together. +See `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py` for a standalone example, +or `samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py` for an end-to-end +example with Azure AI Foundry agents. 
diff --git a/python/packages/azure-cosmos/samples/.env.template b/python/packages/azure-cosmos/samples/.env.template
deleted file mode 100644
index c200634623..0000000000
--- a/python/packages/azure-cosmos/samples/.env.template
+++ /dev/null
@@ -1,6 +0,0 @@
-AZURE_AI_PROJECT_ENDPOINT
-AZURE_AI_MODEL_DEPLOYMENT_NAME
-AZURE_COSMOS_ENDPOINT
-AZURE_COSMOS_KEY
-AZURE_COSMOS_DATABASE_NAME
-AZURE_COSMOS_CONTAINER_NAME
diff --git a/python/packages/azure-cosmos/samples/README.md b/python/packages/azure-cosmos/samples/README.md
deleted file mode 100644
index 4c448f159f..0000000000
--- a/python/packages/azure-cosmos/samples/README.md
+++ /dev/null
@@ -1,56 +0,0 @@
-# Azure Cosmos DB Package Samples
-
-This folder contains samples for `agent-framework-azure-cosmos`.
-
-## History Provider Samples
-
-Demonstrate conversation persistence using `CosmosHistoryProvider`.
-
-| File | Description |
-| --- | --- |
-| [`history_provider/cosmos_history_basic.py`](history_provider/cosmos_history_basic.py) | Basic multi-turn conversation using `CosmosHistoryProvider` with `AzureOpenAIResponsesClient`, provider-configured container name, and `session_id` partitioning. |
-| [`history_provider/cosmos_history_conversation_persistence.py`](history_provider/cosmos_history_conversation_persistence.py) | Persist and resume conversations across application restarts — serialize session state, create new provider/agent instances, and continue from Cosmos DB history. |
-| [`history_provider/cosmos_history_messages.py`](history_provider/cosmos_history_messages.py) | Direct message history operations — retrieve stored messages as a transcript, clear session history, and verify data deletion. |
-| [`history_provider/cosmos_history_sessions.py`](history_provider/cosmos_history_sessions.py) | Multi-session and multi-tenant management — per-tenant session isolation, `list_sessions()` to enumerate, switch between sessions, and resume specific conversations. |
-
-## Checkpoint Storage Samples
-
-Demonstrate workflow pause/resume using `CosmosCheckpointStorage`.
-
-| File | Description |
-| --- | --- |
-| [`checkpoint_storage/cosmos_checkpoint_workflow.py`](checkpoint_storage/cosmos_checkpoint_workflow.py) | Workflow checkpoint storage with Cosmos DB — pause and resume workflows across restarts using `CosmosCheckpointStorage`, with support for key-based and managed identity auth. |
-| [`checkpoint_storage/cosmos_checkpoint_foundry.py`](checkpoint_storage/cosmos_checkpoint_foundry.py) | End-to-end Azure AI Foundry + Cosmos DB checkpointing — multi-agent workflow using `AzureOpenAIResponsesClient` with `CosmosCheckpointStorage` for durable pause/resume. |
-
-## Combined Sample
-
-| File | Description |
-| --- | --- |
-| [`cosmos_e2e_foundry.py`](cosmos_e2e_foundry.py) | Both `CosmosHistoryProvider` and `CosmosCheckpointStorage` in a single Azure AI Foundry agent app — the recommended production pattern for fully durable agent workflows. |
-
-## Prerequisites
-
-- `AZURE_COSMOS_ENDPOINT`
-- `AZURE_COSMOS_DATABASE_NAME`
-- `AZURE_COSMOS_KEY` (or equivalent credential flow)
-
-For Foundry samples, also set:
-- `AZURE_AI_PROJECT_ENDPOINT`
-- `AZURE_AI_MODEL_DEPLOYMENT_NAME`
-
-## Run
-
-```bash
-# History provider samples
-uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_basic.py
-uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_conversation_persistence.py
-uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_messages.py
-uv run --directory packages/azure-cosmos python samples/history_provider/cosmos_history_sessions.py
-
-# Checkpoint storage samples
-uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_workflow.py
-uv run --directory packages/azure-cosmos python samples/checkpoint_storage/cosmos_checkpoint_foundry.py
-
-# Combined sample
-uv run --directory packages/azure-cosmos python samples/cosmos_e2e_foundry.py
-```
diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py b/python/packages/azure-cosmos/samples/checkpoint_storage/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py b/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py
deleted file mode 100644
index 4726742ffc..0000000000
--- a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_workflow.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-# ruff: noqa: T201
-
-"""Sample: Workflow Checkpointing with Cosmos DB NoSQL.
-
-Purpose:
-This sample shows how to use Azure Cosmos DB NoSQL as a persistent checkpoint
-storage backend for workflows, enabling durable pause-and-resume across
-process restarts.
-
-What you learn:
-- How to configure CosmosCheckpointStorage for workflow checkpointing
-- How to run a workflow that automatically persists checkpoints to Cosmos DB
-- How to resume a workflow from a Cosmos DB checkpoint
-- How to list and inspect available checkpoints
-
-Prerequisites:
-- An Azure Cosmos DB account (or local emulator)
-- Environment variables set (see below)
-
-Environment variables:
-    AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
-    AZURE_COSMOS_DATABASE_NAME - Database name
-    AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints
-Optional:
-    AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
-"""
-
-import asyncio
-import os
-import sys
-from dataclasses import dataclass
-from typing import Any
-
-from agent_framework import (
-    Executor,
-    WorkflowBuilder,
-    WorkflowCheckpoint,
-    WorkflowContext,
-    handler,
-)
-
-if sys.version_info >= (3, 12):
-    from typing import override  # type: ignore  # pragma: no cover
-else:
-    from typing_extensions import override  # type: ignore[import]  # pragma: no cover
-
-from agent_framework_azure_cosmos import CosmosCheckpointStorage
-
-
-@dataclass
-class ComputeTask:
-    """Task containing the list of numbers remaining to be processed."""
-
-    remaining_numbers: list[int]
-
-
-class StartExecutor(Executor):
-    """Initiates the workflow by providing the upper limit."""
-
-    @handler
-    async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
-        """Start the workflow with numbers up to the given limit."""
-        print(f"StartExecutor: Starting computation up to {upper_limit}")
-        await ctx.send_message(ComputeTask(remaining_numbers=list(range(1, upper_limit + 1))))
-
-
-class WorkerExecutor(Executor):
-    """Processes numbers and manages executor state for checkpointing."""
-
-    def __init__(self, id: str) -> None:
-        """Initialize the worker executor."""
-        super().__init__(id=id)
-        self._results: dict[int, list[tuple[int, int]]] = {}
-
-    @handler
-    async def compute(
-        self,
-        task: ComputeTask,
-        ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
-    ) -> None:
-        """Process the next number, computing its factor pairs."""
-        next_number = task.remaining_numbers.pop(0)
-        print(f"WorkerExecutor: Processing {next_number}")
-
-        pairs: list[tuple[int, int]] = []
-        for i in range(1, next_number):
-            if next_number % i == 0:
-                pairs.append((i, next_number // i))
-        self._results[next_number] = pairs
-
-        if not task.remaining_numbers:
-            await ctx.yield_output(self._results)
-        else:
-            await ctx.send_message(task)
-
-    @override
-    async def on_checkpoint_save(self) -> dict[str, Any]:
-        return {"results": self._results}
-
-    @override
-    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
-        self._results = state.get("results", {})
-
-
-async def main() -> None:
-    """Run the workflow checkpointing sample with Cosmos DB."""
-    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
-    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
-    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
-    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
-
-    if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name:
-        print(
-            "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, "
-            "and AZURE_COSMOS_CONTAINER_NAME."
-        )
-        return
-
-    # Authentication: supports both managed identity/RBAC and key-based auth.
-    # When AZURE_COSMOS_KEY is set, key-based auth is used.
-    # Otherwise, falls back to DefaultAzureCredential (properly closed via async with).
-    if cosmos_key:
-        async with CosmosCheckpointStorage(
-            endpoint=cosmos_endpoint,
-            credential=cosmos_key,
-            database_name=cosmos_database_name,
-            container_name=cosmos_container_name,
-        ) as checkpoint_storage:
-            await _run_workflow(checkpoint_storage)
-    else:
-        from azure.identity.aio import DefaultAzureCredential
-
-        async with DefaultAzureCredential() as credential, CosmosCheckpointStorage(
-            endpoint=cosmos_endpoint,
-            credential=credential,
-            database_name=cosmos_database_name,
-            container_name=cosmos_container_name,
-        ) as checkpoint_storage:
-            await _run_workflow(checkpoint_storage)
-
-
-async def _run_workflow(checkpoint_storage: CosmosCheckpointStorage) -> None:
-    """Build and run the workflow with Cosmos DB checkpointing."""
-    start = StartExecutor(id="start")
-    worker = WorkerExecutor(id="worker")
-    workflow_builder = (
-        WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
-        .add_edge(start, worker)
-        .add_edge(worker, worker)
-    )
-
-    # --- First run: execute the workflow ---
-    print("\n=== First Run ===\n")
-    workflow = workflow_builder.build()
-
-    output = None
-    async for event in workflow.run(message=8, stream=True):
-        if event.type == "output":
-            output = event.data
-
-    print(f"Factor pairs computed: {output}")
-
-    # List checkpoints saved in Cosmos DB
-    checkpoint_ids = await checkpoint_storage.list_checkpoint_ids(
-        workflow_name=workflow.name,
-    )
-    print(f"\nCheckpoints in Cosmos DB: {len(checkpoint_ids)}")
-    for cid in checkpoint_ids:
-        print(f"  - {cid}")
-
-    # Get the latest checkpoint
-    latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
-        workflow_name=workflow.name,
-    )
-
-    if latest is None:
-        print("No checkpoint found to resume from.")
-        return
-
-    print(f"\nLatest checkpoint: {latest.checkpoint_id}")
-    print(f"  iteration_count: {latest.iteration_count}")
-    print(f"  timestamp: {latest.timestamp}")
-
-    # --- Second run: resume from the latest checkpoint ---
-    print("\n=== Resuming from Checkpoint ===\n")
-    workflow2 = workflow_builder.build()
-
-    output2 = None
-    async for event in workflow2.run(checkpoint_id=latest.checkpoint_id, stream=True):
-        if event.type == "output":
-            output2 = event.data
-
-    if output2:
-        print(f"Resumed workflow produced: {output2}")
-    else:
-        print("Resumed workflow completed (no remaining work — already finished).")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
diff --git a/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py b/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py
deleted file mode 100644
index 1e8a20c388..0000000000
--- a/python/packages/azure-cosmos/samples/cosmos_e2e_foundry.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-# ruff: noqa: T201
-
-"""Sample: Combined History Provider + Checkpoint Storage with Azure AI Foundry.
-
-Purpose:
-This sample demonstrates using both CosmosHistoryProvider (conversation memory)
-and CosmosCheckpointStorage (workflow pause/resume) together in a single Azure AI
-Foundry agent application. This is the recommended production pattern for customers
-who need both durable conversations and durable workflow execution.
-
-What you learn:
-- How to wire both CosmosHistoryProvider and CosmosCheckpointStorage in one app
-- How conversation history and workflow checkpoints serve complementary roles
-- How to resume both conversation context and workflow execution state
-
-Key concepts:
-- CosmosHistoryProvider: Persists conversation messages across sessions
-- CosmosCheckpointStorage: Persists workflow execution state for pause/resume
-- Together they enable fully durable agent workflows
-
-Environment variables:
-    AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint
-    AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name
-    AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint
-    AZURE_COSMOS_DATABASE_NAME - Database name
-Optional:
-    AZURE_COSMOS_KEY - Account key (if not using Azure credentials)
-"""
-
-import asyncio
-import os
-from typing import Any
-
-from agent_framework import WorkflowAgent, WorkflowCheckpoint
-from agent_framework.azure import AzureOpenAIResponsesClient
-from agent_framework.orchestrations import SequentialBuilder
-from azure.identity.aio import AzureCliCredential
-from dotenv import load_dotenv
-
-from agent_framework_azure_cosmos import CosmosCheckpointStorage, CosmosHistoryProvider
-
-load_dotenv()
-
-
-async def main() -> None:
-    """Run the combined history + checkpoint sample."""
-    project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
-    deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")
-    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
-    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
-    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
-
-    if not project_endpoint or not deployment_name:
-        print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.")
-        return
-
-    if not cosmos_endpoint or not cosmos_database_name:
-        print("Please set AZURE_COSMOS_ENDPOINT and AZURE_COSMOS_DATABASE_NAME.")
-        return
-
-    async with AzureCliCredential() as azure_credential:
-        cosmos_credential: Any = cosmos_key if cosmos_key else azure_credential
-
-        # CosmosHistoryProvider: stores conversation messages
-        # CosmosCheckpointStorage: stores workflow execution state
-        async with (
-            CosmosHistoryProvider(
-                endpoint=cosmos_endpoint,
-                database_name=cosmos_database_name,
-                container_name="conversation-history",
-                credential=cosmos_credential,
-            ) as history_provider,
-            CosmosCheckpointStorage(
-                endpoint=cosmos_endpoint,
-                credential=cosmos_credential,
-                database_name=cosmos_database_name,
-                container_name="workflow-checkpoints",
-            ) as checkpoint_storage,
-        ):
-            # Create Azure AI Foundry agents
-            client = AzureOpenAIResponsesClient(
-                project_endpoint=project_endpoint,
-                deployment_name=deployment_name,
-                credential=azure_credential,
-            )
-
-            assistant = client.as_agent(
-                name="assistant",
-                instructions="You are a helpful assistant. Keep responses brief.",
-            )
-
-            reviewer = client.as_agent(
-                name="reviewer",
-                instructions=(
-                    "You are a reviewer. Provide a one-sentence "
-                    "summary of the assistant's response."
-                ),
-            )
-
-            # Build a workflow with both history and checkpointing.
-            # Attach the history provider to the WorkflowAgent (outer agent)
-            # so conversation messages are persisted at the agent level.
-            workflow = SequentialBuilder(
-                participants=[assistant, reviewer],
-            ).build()
-            agent = WorkflowAgent(
-                workflow,
-                name="DurableAgent",
-                context_providers=[history_provider],
-            )
-
-            # --- First run ---
-            print("=== First Run ===\n")
-            session = agent.create_session()
-
-            response = await agent.run(
-                "What are three benefits of cloud computing?",
-                session=session,
-                checkpoint_storage=checkpoint_storage,
-            )
-
-            for msg in response.messages:
-                speaker = msg.author_name or msg.role
-                print(f"[{speaker}]: {msg.text}")
-
-            # Show what's persisted
-            checkpoints = await checkpoint_storage.list_checkpoints(
-                workflow_name=workflow.name,
-            )
-            history = await history_provider.get_messages(session.session_id)
-
-            print(f"\nConversation messages in Cosmos DB: {len(history)}")
-            print(f"Workflow checkpoints in Cosmos DB: {len(checkpoints)}")
-
-            # --- Second run: conversation context is loaded from history ---
-            print("\n=== Second Run (with conversation context) ===\n")
-
-            response2 = await agent.run(
-                "Can you elaborate on the first benefit?",
-                session=session,
-                checkpoint_storage=checkpoint_storage,
-            )
-
-            for msg in response2.messages:
-                speaker = msg.author_name or msg.role
-                print(f"[{speaker}]: {msg.text}")
-
-            # Show updated state
-            latest: WorkflowCheckpoint | None = await checkpoint_storage.get_latest(
-                workflow_name=workflow.name,
-            )
-            history2 = await history_provider.get_messages(session.session_id)
-
-            print(f"\nConversation messages after 2 runs: {len(history2)}")
-            if latest:
-                print(f"Latest checkpoint: {latest.checkpoint_id}")
-                print(f"  iteration_count: {latest.iteration_count}")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
diff --git a/python/packages/azure-cosmos/samples/history_provider/__init__.py b/python/packages/azure-cosmos/samples/history_provider/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py
deleted file mode 100644
index 5dc5c54b65..0000000000
--- a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_basic.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-
-import asyncio
-import os
-
-from agent_framework import Agent
-from agent_framework.azure import CosmosHistoryProvider
-from agent_framework.foundry import FoundryChatClient
-from azure.identity.aio import AzureCliCredential
-from dotenv import load_dotenv
-
-# Load environment variables from .env file.
-load_dotenv()
-
-"""
-This sample demonstrates CosmosHistoryProvider as an agent history provider.
-
-Key components:
-- FoundryChatClient configured with an Azure AI project endpoint
-- CosmosHistoryProvider configured for Cosmos DB-backed message history
-- Provider-configured container name with session_id as partition key
-
-Environment variables:
-    FOUNDRY_PROJECT_ENDPOINT
-    FOUNDRY_MODEL
-    AZURE_COSMOS_ENDPOINT
-    AZURE_COSMOS_DATABASE_NAME
-    AZURE_COSMOS_CONTAINER_NAME
-Optional:
-    AZURE_COSMOS_KEY
-"""
-
-
-async def main() -> None:
-    """Run the Cosmos history provider sample with an Agent."""
-    project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT")
-    model = os.getenv("FOUNDRY_MODEL")
-    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
-    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
-    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
-    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
-
-    if (
-        not project_endpoint
-        or not model
-        or not cosmos_endpoint
-        or not cosmos_database_name
-        or not cosmos_container_name
-    ):
-        print(
-            "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, "
-            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
-        )
-        return
-
-    # 1. Create an Azure credential and a CosmosHistoryProvider for agent context
-    async with (
-        AzureCliCredential() as credential,
-        CosmosHistoryProvider(
-            endpoint=cosmos_endpoint,
-            database_name=cosmos_database_name,
-            container_name=cosmos_container_name,
-            credential=cosmos_key or credential,
-        ) as history_provider,
-        # 2. Create an agent that uses Cosmos for persisted conversation history.
-        Agent(
-            client=FoundryChatClient(
-                project_endpoint=project_endpoint,
-                model=model,
-                credential=credential,
-            ),
-            name="CosmosHistoryAgent",
-            instructions="You are a helpful assistant that remembers prior turns.",
-            context_providers=[history_provider],
-            default_options={"store": False},
-        ) as agent,
-    ):
-        # 3. Create a session (session_id is used as the partition key).
-        session = agent.create_session()
-
-        # 4. Run a multi-turn conversation; history is persisted by CosmosHistoryProvider.
-        response1 = await agent.run("My name is Ada and I enjoy distributed systems.", session=session)
-        print(f"Assistant: {response1.text}")
-
-        response2 = await agent.run("What do you remember about me?", session=session)
-        print(f"Assistant: {response2.text}")
-        print(f"Container: {history_provider.container_name}")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
-
-"""
-Sample output:
-Assistant: Nice to meet you, Ada! Distributed systems are a fascinating area.
-Assistant: You told me your name is Ada and that you enjoy distributed systems.
-Container: 
-"""
diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py
deleted file mode 100644
index f8e97e8d17..0000000000
--- a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_conversation_persistence.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-# ruff: noqa: T201
-
-import asyncio
-import os
-
-from agent_framework import AgentSession
-from agent_framework.azure import AzureOpenAIResponsesClient
-from azure.identity.aio import AzureCliCredential
-from dotenv import load_dotenv
-
-from agent_framework_azure_cosmos import CosmosHistoryProvider
-
-# Load environment variables from .env file.
-load_dotenv()
-
-"""
-This sample demonstrates persisting and resuming conversations across application
-restarts using CosmosHistoryProvider as the persistent backend.
-
-Key components:
-- Phase 1: Run a conversation and serialize the session with session.to_dict()
-- Phase 2: Simulate an app restart — create new provider and agent instances,
-  restore the session with AgentSession.from_dict(), and continue the conversation
-- Cosmos DB reloads the full message history, so the agent remembers everything
-
-Environment variables:
-    AZURE_AI_PROJECT_ENDPOINT
-    AZURE_AI_MODEL_DEPLOYMENT_NAME
-    AZURE_COSMOS_ENDPOINT
-    AZURE_COSMOS_DATABASE_NAME
-    AZURE_COSMOS_CONTAINER_NAME
-Optional:
-    AZURE_COSMOS_KEY
-"""
-
-
-async def main() -> None:
-    """Run the conversation persistence sample."""
-    project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
-    deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")
-    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
-    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
-    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
-    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
-
-    if (
-        not project_endpoint
-        or not deployment_name
-        or not cosmos_endpoint
-        or not cosmos_database_name
-        or not cosmos_container_name
-    ):
-        print(
-            "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, "
-            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
-        )
-        return
-
-    # ── Phase 1: Initial conversation ──
-
-    print("=== Phase 1: Initial conversation ===\n")
-
-    # 1. Create credential, client, history provider, and agent.
-    async with AzureCliCredential() as credential:
-        client = AzureOpenAIResponsesClient(
-            project_endpoint=project_endpoint,
-            deployment_name=deployment_name,
-            credential=credential,
-        )
-
-        async with (
-            CosmosHistoryProvider(
-                endpoint=cosmos_endpoint,
-                database_name=cosmos_database_name,
-                container_name=cosmos_container_name,
-                credential=cosmos_key or credential,
-            ) as history_provider,
-            client.as_agent(
-                name="PersistentAgent",
-                instructions="You are a helpful assistant that remembers prior turns.",
-                context_providers=[history_provider],
-                default_options={"store": False},
-            ) as agent,
-        ):
-            # 2. Create a session and have a multi-turn conversation.
-            session = agent.create_session()
-
-            response1 = await agent.run(
-                "My name is Ada. I'm building a distributed database in Rust.", session=session
-            )
-            print("User: My name is Ada. I'm building a distributed database in Rust.")
-            print(f"Assistant: {response1.text}\n")
-
-            response2 = await agent.run("The hardest part is the consensus algorithm.", session=session)
-            print("User: The hardest part is the consensus algorithm.")
-            print(f"Assistant: {response2.text}\n")
-
-            # 3. Serialize the session state — this is what you'd persist to a database or file.
-            serialized_session = session.to_dict()
-            print(f"Session serialized. Session ID: {session.session_id}")
-
-    # ── Phase 2: Simulate app restart ──
-
-    print("\n=== Phase 2: Resuming after 'restart' ===\n")
-
-    # 4. Create entirely new provider and agent instances (simulating a fresh process).
-    async with AzureCliCredential() as credential:
-        client = AzureOpenAIResponsesClient(
-            project_endpoint=project_endpoint,
-            deployment_name=deployment_name,
-            credential=credential,
-        )
-
-        async with (
-            CosmosHistoryProvider(
-                endpoint=cosmos_endpoint,
-                database_name=cosmos_database_name,
-                container_name=cosmos_container_name,
-                credential=cosmos_key or credential,
-            ) as history_provider,
-            client.as_agent(
-                name="PersistentAgent",
-                instructions="You are a helpful assistant that remembers prior turns.",
-                context_providers=[history_provider],
-                default_options={"store": False},
-            ) as agent,
-        ):
-            # 5. Restore the session from the serialized state.
-            restored_session = AgentSession.from_dict(serialized_session)
-            print(f"Session restored. Session ID: {restored_session.session_id}\n")
-
-            # 6. Continue the conversation — history is reloaded from Cosmos DB.
-            response3 = await agent.run("What was I working on and what was the challenge?", session=restored_session)
-            print("User: What was I working on and what was the challenge?")
-            print(f"Assistant: {response3.text}\n")
-
-            # 7. Verify messages are in Cosmos by reading them directly.
-            messages = await history_provider.get_messages(restored_session.session_id)
-            print(f"Messages stored in Cosmos DB: {len(messages)}")
-            for i, msg in enumerate(messages, 1):
-                print(f"  {i}. [{msg.role}] {msg.text[:80]}...")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
-
-"""
-Sample output:
-=== Phase 1: Initial conversation ===
-
-User: My name is Ada. I'm building a distributed database in Rust.
-Assistant: That sounds like a great project, Ada! Rust is an excellent choice for ...
-
-User: The hardest part is the consensus algorithm.
-Assistant: Consensus algorithms can be tricky! Are you looking at Raft, Paxos, or ...
-
-Session serialized. Session ID: 
-
-=== Phase 2: Resuming after 'restart' ===
-
-Session restored. Session ID: 
-
-User: What was I working on and what was the challenge?
-Assistant: You told me you're building a distributed database in Rust and that the hardest
-part is the consensus algorithm.
-
-Messages stored in Cosmos DB: 6
-  1. [user] My name is Ada. I'm building a distributed database in Rust....
-  2. [assistant] That sounds like a great project, Ada! Rust is an excellent ch...
-  3. [user] The hardest part is the consensus algorithm....
-  4. [assistant] Consensus algorithms can be tricky! Are you looking at Raft, Pa...
-  5. [user] What was I working on and what was the challenge?...
-  6. [assistant] You told me you're building a distributed database in Rust and ...
-"""
diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py
deleted file mode 100644
index c97504fd40..0000000000
--- a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_messages.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-# ruff: noqa: T201
-
-import asyncio
-import os
-
-from agent_framework.azure import AzureOpenAIResponsesClient
-from azure.identity.aio import AzureCliCredential
-from dotenv import load_dotenv
-
-from agent_framework_azure_cosmos import CosmosHistoryProvider
-
-# Load environment variables from .env file.
-load_dotenv()
-
-"""
-This sample demonstrates direct message history operations using
-CosmosHistoryProvider — retrieving, displaying, and clearing stored messages.
-
-Key components:
-- get_messages(session_id): Retrieve all stored messages as a chat transcript
-- clear(session_id): Delete all messages for a session (e.g., GDPR compliance)
-- Verifying that history is empty after clearing
-- Running a new conversation in the same session after clearing
-
-Environment variables:
-    AZURE_AI_PROJECT_ENDPOINT
-    AZURE_AI_MODEL_DEPLOYMENT_NAME
-    AZURE_COSMOS_ENDPOINT
-    AZURE_COSMOS_DATABASE_NAME
-    AZURE_COSMOS_CONTAINER_NAME
-Optional:
-    AZURE_COSMOS_KEY
-"""
-
-
-async def main() -> None:
-    """Run the messages history sample."""
-    project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT")
-    deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME")
-    cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT")
-    cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME")
-    cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME")
-    cosmos_key = os.getenv("AZURE_COSMOS_KEY")
-
-    if (
-        not project_endpoint
-        or not deployment_name
-        or not cosmos_endpoint
-        or not cosmos_database_name
-        or not cosmos_container_name
-    ):
-        print(
-            "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, "
-            "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME."
-        )
-        return
-
-    async with AzureCliCredential() as credential:
-        client = AzureOpenAIResponsesClient(
-            project_endpoint=project_endpoint,
-            deployment_name=deployment_name,
-            credential=credential,
-        )
-
-        async with (
-            CosmosHistoryProvider(
-                endpoint=cosmos_endpoint,
-                database_name=cosmos_database_name,
-                container_name=cosmos_container_name,
-                credential=cosmos_key or credential,
-            ) as history_provider,
-            client.as_agent(
-                name="HistoryAgent",
-                instructions="You are a helpful assistant that remembers prior turns.",
-                context_providers=[history_provider],
-                default_options={"store": False},
-            ) as agent,
-        ):
-            session = agent.create_session()
-            session_id = session.session_id
-
-            # 1. Have a multi-turn conversation.
-            print("=== Building a conversation ===\n")
-
-            queries = [
-                "Hi! My favorite programming language is Python.",
-                "I also enjoy hiking in the mountains on weekends.",
-                "What do you know about me so far?",
-            ]
-            for query in queries:
-                response = await agent.run(query, session=session)
-                print(f"User: {query}")
-                print(f"Assistant: {response.text}\n")
-
-            # 2. Retrieve and display the full message history as a transcript.
-            print("=== Chat transcript from Cosmos DB ===\n")
-
-            messages = await history_provider.get_messages(session_id)
-            print(f"Total messages stored: {len(messages)}\n")
-            for i, msg in enumerate(messages, 1):
-                print(f"  {i}. [{msg.role}] {msg.text[:100]}")
-
-            # 3. Clear the session history.
-            print("\n=== Clearing session history ===\n")
-
-            await history_provider.clear(session_id)
-            print(f"Cleared all messages for session: {session_id}")
-
-            # 4. Verify history is empty.
-            remaining = await history_provider.get_messages(session_id)
-            print(f"Messages after clear: {len(remaining)}")
-
-            # 5. Start a fresh conversation in the same session — agent has no memory.
-            print("\n=== Fresh conversation (same session, no memory) ===\n")
-
-            response = await agent.run("What do you know about me?", session=session)
-            print("User: What do you know about me?")
-            print(f"Assistant: {response.text}")
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
-
-"""
-Sample output:
-=== Building a conversation ===
-
-User: Hi! My favorite programming language is Python.
-Assistant: That's great! Python is a wonderful language. What do you like most about it?
-
-User: I also enjoy hiking in the mountains on weekends.
-Assistant: Hiking sounds lovely! Do you have a favorite trail or mountain range?
-
-User: What do you know about me so far?
-Assistant: You love Python as your favorite programming language and enjoy hiking in the mountains on weekends.
-
-=== Chat transcript from Cosmos DB ===
-
-Total messages stored: 6
-
-  1. [user] Hi! My favorite programming language is Python.
-  2. [assistant] That's great! Python is a wonderful language. What do you like most about it?
-  3. [user] I also enjoy hiking in the mountains on weekends.
-  4. [assistant] Hiking sounds lovely! Do you have a favorite trail or mountain range?
-  5. [user] What do you know about me so far?
-  6. [assistant] You love Python as your favorite programming language and enjoy hiking ...
-
-=== Clearing session history ===
-
-Cleared all messages for session: 
-Messages after clear: 0
-
-=== Fresh conversation (same session, no memory) ===
-
-User: What do you know about me?
-Assistant: I don't have any information about you yet. Feel free to share anything you'd like!
-"""
diff --git a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py b/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py
deleted file mode 100644
index 6772d41825..0000000000
--- a/python/packages/azure-cosmos/samples/history_provider/cosmos_history_sessions.py
+++ /dev/null
@@ -1,198 +0,0 @@
-# Copyright (c) Microsoft. All rights reserved.
-# ruff: noqa: T201
-
-import asyncio
-import os
-
-from agent_framework.azure import AzureOpenAIResponsesClient
-from azure.identity.aio import AzureCliCredential
-from dotenv import load_dotenv
-
-from agent_framework_azure_cosmos import CosmosHistoryProvider
-
-# Load environment variables from .env file.
-load_dotenv()
-
-"""
-This sample demonstrates multi-session and multi-tenant management using
-CosmosHistoryProvider. Each tenant (user) gets isolated conversation sessions
-stored in the same Cosmos DB container, partitioned by session_id.
- -Key components: -- Per-tenant session isolation using prefixed session IDs -- list_sessions(): Enumerate all stored sessions across tenants -- Switching between sessions for different users -- Resuming a specific user's session — verifying data isolation - -Environment variables: - AZURE_AI_PROJECT_ENDPOINT - AZURE_AI_MODEL_DEPLOYMENT_NAME - AZURE_COSMOS_ENDPOINT - AZURE_COSMOS_DATABASE_NAME - AZURE_COSMOS_CONTAINER_NAME -Optional: - AZURE_COSMOS_KEY -""" - - -async def main() -> None: - """Run the session management sample.""" - project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT") - deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") - cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") - cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") - cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") - cosmos_key = os.getenv("AZURE_COSMOS_KEY") - - if ( - not project_endpoint - or not deployment_name - or not cosmos_endpoint - or not cosmos_database_name - or not cosmos_container_name - ): - print( - "Please set AZURE_AI_PROJECT_ENDPOINT, AZURE_AI_MODEL_DEPLOYMENT_NAME, " - "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." - ) - return - - async with AzureCliCredential() as credential: - client = AzureOpenAIResponsesClient( - project_endpoint=project_endpoint, - deployment_name=deployment_name, - credential=credential, - ) - - async with ( - CosmosHistoryProvider( - endpoint=cosmos_endpoint, - database_name=cosmos_database_name, - container_name=cosmos_container_name, - credential=cosmos_key or credential, - ) as history_provider, - client.as_agent( - name="MultiTenantAgent", - instructions="You are a helpful assistant that remembers prior turns.", - context_providers=[history_provider], - default_options={"store": False}, - ) as agent, - ): - # 1. Tenant "alice" starts a conversation about travel. 
- print("=== Tenant: Alice — Travel conversation ===\n") - - alice_session = agent.create_session(session_id="tenant-alice-session-1") - - response = await agent.run( - "Hi! I'm planning a trip to Italy. I love Renaissance art.", session=alice_session - ) - print("Alice: I'm planning a trip to Italy. I love Renaissance art.") - print(f"Assistant: {response.text}\n") - - response = await agent.run("Which museums should I visit in Florence?", session=alice_session) - print("Alice: Which museums should I visit in Florence?") - print(f"Assistant: {response.text}\n") - - # 2. Tenant "bob" starts a separate conversation about cooking. - print("=== Tenant: Bob — Cooking conversation ===\n") - - bob_session = agent.create_session(session_id="tenant-bob-session-1") - - response = await agent.run( - "Hey! I'm learning to cook Thai food. I just made pad thai.", session=bob_session - ) - print("Bob: I'm learning to cook Thai food. I just made pad thai.") - print(f"Assistant: {response.text}\n") - - response = await agent.run("What Thai dish should I try next?", session=bob_session) - print("Bob: What Thai dish should I try next?") - print(f"Assistant: {response.text}\n") - - # 3. List all sessions stored in Cosmos DB. - print("=== Listing all sessions ===\n") - - sessions = await history_provider.list_sessions() - print(f"Found {len(sessions)} session(s):") - for sid in sessions: - print(f" - {sid}") - - # 4. Resume Alice's session — verify she gets her travel context back. - print("\n=== Resuming Alice's session ===\n") - - alice_resumed = agent.create_session(session_id="tenant-alice-session-1") - - response = await agent.run("What were we discussing?", session=alice_resumed) - print("Alice: What were we discussing?") - print(f"Assistant: {response.text}\n") - - # 5. Resume Bob's session — verify he gets his cooking context back. 
- print("=== Resuming Bob's session ===\n") - - bob_resumed = agent.create_session(session_id="tenant-bob-session-1") - - response = await agent.run("What was the last dish I mentioned?", session=bob_resumed) - print("Bob: What was the last dish I mentioned?") - print(f"Assistant: {response.text}\n") - - # 6. Show per-session message counts. - print("=== Per-session message counts ===\n") - - alice_messages = await history_provider.get_messages("tenant-alice-session-1") - bob_messages = await history_provider.get_messages("tenant-bob-session-1") - print(f"Alice's session: {len(alice_messages)} messages") - print(f"Bob's session: {len(bob_messages)} messages") - - # 7. Clean up: clear both sessions. - print("\n=== Cleaning up ===\n") - - await history_provider.clear("tenant-alice-session-1") - await history_provider.clear("tenant-bob-session-1") - print("Cleared Alice's and Bob's sessions.") - - -if __name__ == "__main__": - asyncio.run(main()) - -""" -Sample output: -=== Tenant: Alice — Travel conversation === - -Alice: I'm planning a trip to Italy. I love Renaissance art. -Assistant: Italy is a dream for Renaissance art lovers! Florence, Rome, and Venice ... - -Alice: Which museums should I visit in Florence? -Assistant: In Florence, the Uffizi Gallery is a must — it has Botticelli's Birth of Venus ... - -=== Tenant: Bob — Cooking conversation === - -Bob: I'm learning to cook Thai food. I just made pad thai. -Assistant: Pad thai is a great start! How did it turn out? - -Bob: What Thai dish should I try next? -Assistant: I'd suggest trying green curry or tom yum soup — both are classic Thai dishes ... - -=== Listing all sessions === - -Found 2 session(s): - - tenant-alice-session-1 - - tenant-bob-session-1 - -=== Resuming Alice's session === - -Alice: What were we discussing? -Assistant: We were discussing your trip to Italy and your love for Renaissance art ... - -=== Resuming Bob's session === - -Bob: What was the last dish I mentioned? 
-Assistant: You mentioned pad thai — it was the dish you just made! - -=== Per-session message counts === - -Alice's session: 6 messages -Bob's session: 6 messages - -=== Cleaning up === - -Cleared Alice's and Bob's sessions. -""" diff --git a/python/samples/02-agents/conversations/README.md b/python/samples/02-agents/conversations/README.md index becacd9134..bbfb078659 100644 --- a/python/samples/02-agents/conversations/README.md +++ b/python/samples/02-agents/conversations/README.md @@ -9,6 +9,9 @@ These samples demonstrate different approaches to managing conversation history | [`suspend_resume_session.py`](suspend_resume_session.py) | Suspend and resume conversation sessions, comparing service-managed sessions (Azure AI Foundry) with in-memory sessions (OpenAI). | | [`custom_history_provider.py`](custom_history_provider.py) | Implement a custom history provider by extending `HistoryProvider`, enabling conversation persistence in your preferred storage backend. | | [`cosmos_history_provider.py`](cosmos_history_provider.py) | Use Azure Cosmos DB as a history provider for durable conversation storage with `CosmosHistoryProvider`. | +| [`cosmos_history_provider_conversation_persistence.py`](cosmos_history_provider_conversation_persistence.py) | Persist and resume conversations across application restarts using `CosmosHistoryProvider` — serialize session state, restore it, and continue with full Cosmos DB history. | +| [`cosmos_history_provider_messages.py`](cosmos_history_provider_messages.py) | Direct message history operations — retrieve stored messages as a transcript, clear session history, and verify data deletion. | +| [`cosmos_history_provider_sessions.py`](cosmos_history_provider_sessions.py) | Multi-session and multi-tenant management — per-tenant session isolation, `list_sessions()` to enumerate, switch between sessions, and resume specific conversations. 
| | [`redis_history_provider.py`](redis_history_provider.py) | Use Redis as a history provider for persistent conversation history storage across sessions. | ## Prerequisites @@ -22,7 +25,7 @@ These samples demonstrate different approaches to managing conversation history **For `custom_history_provider.py`:** - `OPENAI_API_KEY`: Your OpenAI API key -**For `cosmos_history_provider.py`:** +**For Cosmos DB samples (`cosmos_history_provider*.py`):** - `FOUNDRY_PROJECT_ENDPOINT`: Your Azure AI Foundry project endpoint - `FOUNDRY_MODEL`: The Foundry model deployment name - `AZURE_COSMOS_ENDPOINT`: Your Azure Cosmos DB account endpoint diff --git a/python/samples/02-agents/conversations/cosmos_history_provider_conversation_persistence.py b/python/samples/02-agents/conversations/cosmos_history_provider_conversation_persistence.py new file mode 100644 index 0000000000..ef2b444d28 --- /dev/null +++ b/python/samples/02-agents/conversations/cosmos_history_provider_conversation_persistence.py @@ -0,0 +1,165 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework import Agent, AgentSession +from agent_framework.foundry import FoundryChatClient +from agent_framework_azure_cosmos import CosmosHistoryProvider +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates persisting and resuming conversations across application +restarts using CosmosHistoryProvider as the persistent backend. 
+ +Key components: +- Phase 1: Run a conversation and serialize the session with session.to_dict() +- Phase 2: Simulate an app restart — create new provider and agent instances, + restore the session with AgentSession.from_dict(), and continue the conversation +- Cosmos DB reloads the full message history, so the agent remembers everything + +Environment variables: + FOUNDRY_PROJECT_ENDPOINT + FOUNDRY_MODEL + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the conversation persistence sample.""" + project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT") + model = os.getenv("FOUNDRY_MODEL") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not model + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + # ── Phase 1: Initial conversation ── + + print("=== Phase 1: Initial conversation ===\n") + + async with ( + AzureCliCredential() as credential, + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + Agent( + client=FoundryChatClient( + project_endpoint=project_endpoint, + model=model, + credential=credential, + ), + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + session = agent.create_session() + + response1 = await agent.run( + "My name is Ada. 
I'm building a distributed database in Rust.", session=session + ) + print("User: My name is Ada. I'm building a distributed database in Rust.") + print(f"Assistant: {response1.text}\n") + + response2 = await agent.run("The hardest part is the consensus algorithm.", session=session) + print("User: The hardest part is the consensus algorithm.") + print(f"Assistant: {response2.text}\n") + + serialized_session = session.to_dict() + print(f"Session serialized. Session ID: {session.session_id}") + + # ── Phase 2: Simulate app restart ── + + print("\n=== Phase 2: Resuming after 'restart' ===\n") + + async with ( + AzureCliCredential() as credential, + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + Agent( + client=FoundryChatClient( + project_endpoint=project_endpoint, + model=model, + credential=credential, + ), + name="PersistentAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + restored_session = AgentSession.from_dict(serialized_session) + print(f"Session restored. Session ID: {restored_session.session_id}\n") + + response3 = await agent.run("What was I working on and what was the challenge?", session=restored_session) + print("User: What was I working on and what was the challenge?") + print(f"Assistant: {response3.text}\n") + + messages = await history_provider.get_messages(restored_session.session_id) + print(f"Messages stored in Cosmos DB: {len(messages)}") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:80]}...") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Phase 1: Initial conversation === + +User: My name is Ada. I'm building a distributed database in Rust. +Assistant: That sounds like a great project, Ada! 
Rust is an excellent choice for ... + +User: The hardest part is the consensus algorithm. +Assistant: Consensus algorithms can be tricky! Are you looking at Raft, Paxos, or ... + +Session serialized. Session ID: + +=== Phase 2: Resuming after 'restart' === + +Session restored. Session ID: + +User: What was I working on and what was the challenge? +Assistant: You told me you're building a distributed database in Rust and that the hardest +part is the consensus algorithm. + +Messages stored in Cosmos DB: 6 + 1. [user] My name is Ada. I'm building a distributed database in Rust.... + 2. [assistant] That sounds like a great project, Ada! Rust is an excellent ch... + 3. [user] The hardest part is the consensus algorithm.... + 4. [assistant] Consensus algorithms can be tricky! Are you looking at Raft, Pa... + 5. [user] What was I working on and what was the challenge?... + 6. [assistant] You told me you're building a distributed database in Rust and ... +""" diff --git a/python/samples/02-agents/conversations/cosmos_history_provider_messages.py b/python/samples/02-agents/conversations/cosmos_history_provider_messages.py new file mode 100644 index 0000000000..9f8e7c9164 --- /dev/null +++ b/python/samples/02-agents/conversations/cosmos_history_provider_messages.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework import Agent +from agent_framework.foundry import FoundryChatClient +from agent_framework_azure_cosmos import CosmosHistoryProvider +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates direct message history operations using +CosmosHistoryProvider — retrieving, displaying, and clearing stored messages. 
+ +Key components: +- get_messages(session_id): Retrieve all stored messages as a chat transcript +- clear(session_id): Delete all messages for a session (e.g., GDPR compliance) +- Verifying that history is empty after clearing +- Running a new conversation in the same session after clearing + +Environment variables: + FOUNDRY_PROJECT_ENDPOINT + FOUNDRY_MODEL + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the messages history sample.""" + project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT") + model = os.getenv("FOUNDRY_MODEL") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not model + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with ( + AzureCliCredential() as credential, + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + Agent( + client=FoundryChatClient( + project_endpoint=project_endpoint, + model=model, + credential=credential, + ), + name="HistoryAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + session = agent.create_session() + session_id = session.session_id + + # 1. Have a multi-turn conversation. + print("=== Building a conversation ===\n") + + queries = [ + "Hi! 
My favorite programming language is Python.", + "I also enjoy hiking in the mountains on weekends.", + "What do you know about me so far?", + ] + for query in queries: + response = await agent.run(query, session=session) + print(f"User: {query}") + print(f"Assistant: {response.text}\n") + + # 2. Retrieve and display the full message history as a transcript. + print("=== Chat transcript from Cosmos DB ===\n") + + messages = await history_provider.get_messages(session_id) + print(f"Total messages stored: {len(messages)}\n") + for i, msg in enumerate(messages, 1): + print(f" {i}. [{msg.role}] {msg.text[:100]}") + + # 3. Clear the session history. + print("\n=== Clearing session history ===\n") + + await history_provider.clear(session_id) + print(f"Cleared all messages for session: {session_id}") + + # 4. Verify history is empty. + remaining = await history_provider.get_messages(session_id) + print(f"Messages after clear: {len(remaining)}") + + # 5. Start a fresh conversation in the same session — agent has no memory. + print("\n=== Fresh conversation (same session, no memory) ===\n") + + response = await agent.run("What do you know about me?", session=session) + print("User: What do you know about me?") + print(f"Assistant: {response.text}") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Building a conversation === + +User: Hi! My favorite programming language is Python. +Assistant: That's great! Python is a wonderful language. What do you like most about it? + +User: I also enjoy hiking in the mountains on weekends. +Assistant: Hiking sounds lovely! Do you have a favorite trail or mountain range? + +User: What do you know about me so far? +Assistant: You love Python as your favorite programming language and enjoy hiking in the mountains on weekends. + +=== Chat transcript from Cosmos DB === + +Total messages stored: 6 + + 1. [user] Hi! My favorite programming language is Python. + 2. [assistant] That's great! 
Python is a wonderful language. What do you like most about it? + 3. [user] I also enjoy hiking in the mountains on weekends. + 4. [assistant] Hiking sounds lovely! Do you have a favorite trail or mountain range? + 5. [user] What do you know about me so far? + 6. [assistant] You love Python as your favorite programming language and enjoy hiking ... + +=== Clearing session history === + +Cleared all messages for session: +Messages after clear: 0 + +=== Fresh conversation (same session, no memory) === + +User: What do you know about me? +Assistant: I don't have any information about you yet. Feel free to share anything you'd like! +""" diff --git a/python/samples/02-agents/conversations/cosmos_history_provider_sessions.py b/python/samples/02-agents/conversations/cosmos_history_provider_sessions.py new file mode 100644 index 0000000000..2d1861e503 --- /dev/null +++ b/python/samples/02-agents/conversations/cosmos_history_provider_sessions.py @@ -0,0 +1,197 @@ +# Copyright (c) Microsoft. All rights reserved. +# ruff: noqa: T201 + +import asyncio +import os + +from agent_framework import Agent +from agent_framework.foundry import FoundryChatClient +from agent_framework_azure_cosmos import CosmosHistoryProvider +from azure.identity.aio import AzureCliCredential +from dotenv import load_dotenv + +# Load environment variables from .env file. +load_dotenv() + +""" +This sample demonstrates multi-session and multi-tenant management using +CosmosHistoryProvider. Each tenant (user) gets isolated conversation sessions +stored in the same Cosmos DB container, partitioned by session_id. 
+ +Key components: +- Per-tenant session isolation using prefixed session IDs +- list_sessions(): Enumerate all stored sessions across tenants +- Switching between sessions for different users +- Resuming a specific user's session — verifying data isolation + +Environment variables: + FOUNDRY_PROJECT_ENDPOINT + FOUNDRY_MODEL + AZURE_COSMOS_ENDPOINT + AZURE_COSMOS_DATABASE_NAME + AZURE_COSMOS_CONTAINER_NAME +Optional: + AZURE_COSMOS_KEY +""" + + +async def main() -> None: + """Run the session management sample.""" + project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT") + model = os.getenv("FOUNDRY_MODEL") + cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") + cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") + cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") + cosmos_key = os.getenv("AZURE_COSMOS_KEY") + + if ( + not project_endpoint + or not model + or not cosmos_endpoint + or not cosmos_database_name + or not cosmos_container_name + ): + print( + "Please set FOUNDRY_PROJECT_ENDPOINT, FOUNDRY_MODEL, " + "AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, and AZURE_COSMOS_CONTAINER_NAME." + ) + return + + async with ( + AzureCliCredential() as credential, + CosmosHistoryProvider( + endpoint=cosmos_endpoint, + database_name=cosmos_database_name, + container_name=cosmos_container_name, + credential=cosmos_key or credential, + ) as history_provider, + Agent( + client=FoundryChatClient( + project_endpoint=project_endpoint, + model=model, + credential=credential, + ), + name="MultiTenantAgent", + instructions="You are a helpful assistant that remembers prior turns.", + context_providers=[history_provider], + default_options={"store": False}, + ) as agent, + ): + # 1. Tenant "alice" starts a conversation about travel. + print("=== Tenant: Alice — Travel conversation ===\n") + + alice_session = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run( + "Hi! I'm planning a trip to Italy. 
I love Renaissance art.", session=alice_session + ) + print("Alice: I'm planning a trip to Italy. I love Renaissance art.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("Which museums should I visit in Florence?", session=alice_session) + print("Alice: Which museums should I visit in Florence?") + print(f"Assistant: {response.text}\n") + + # 2. Tenant "bob" starts a separate conversation about cooking. + print("=== Tenant: Bob — Cooking conversation ===\n") + + bob_session = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run( + "Hey! I'm learning to cook Thai food. I just made pad thai.", session=bob_session + ) + print("Bob: I'm learning to cook Thai food. I just made pad thai.") + print(f"Assistant: {response.text}\n") + + response = await agent.run("What Thai dish should I try next?", session=bob_session) + print("Bob: What Thai dish should I try next?") + print(f"Assistant: {response.text}\n") + + # 3. List all sessions stored in Cosmos DB. + print("=== Listing all sessions ===\n") + + sessions = await history_provider.list_sessions() + print(f"Found {len(sessions)} session(s):") + for sid in sessions: + print(f" - {sid}") + + # 4. Resume Alice's session — verify she gets her travel context back. + print("\n=== Resuming Alice's session ===\n") + + alice_resumed = agent.create_session(session_id="tenant-alice-session-1") + + response = await agent.run("What were we discussing?", session=alice_resumed) + print("Alice: What were we discussing?") + print(f"Assistant: {response.text}\n") + + # 5. Resume Bob's session — verify he gets his cooking context back. + print("=== Resuming Bob's session ===\n") + + bob_resumed = agent.create_session(session_id="tenant-bob-session-1") + + response = await agent.run("What was the last dish I mentioned?", session=bob_resumed) + print("Bob: What was the last dish I mentioned?") + print(f"Assistant: {response.text}\n") + + # 6. Show per-session message counts. 
+ print("=== Per-session message counts ===\n") + + alice_messages = await history_provider.get_messages("tenant-alice-session-1") + bob_messages = await history_provider.get_messages("tenant-bob-session-1") + print(f"Alice's session: {len(alice_messages)} messages") + print(f"Bob's session: {len(bob_messages)} messages") + + # 7. Clean up: clear both sessions. + print("\n=== Cleaning up ===\n") + + await history_provider.clear("tenant-alice-session-1") + await history_provider.clear("tenant-bob-session-1") + print("Cleared Alice's and Bob's sessions.") + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +Sample output: +=== Tenant: Alice — Travel conversation === + +Alice: I'm planning a trip to Italy. I love Renaissance art. +Assistant: Italy is a dream for Renaissance art lovers! Florence, Rome, and Venice ... + +Alice: Which museums should I visit in Florence? +Assistant: In Florence, the Uffizi Gallery is a must — it has Botticelli's Birth of Venus ... + +=== Tenant: Bob — Cooking conversation === + +Bob: I'm learning to cook Thai food. I just made pad thai. +Assistant: Pad thai is a great start! How did it turn out? + +Bob: What Thai dish should I try next? +Assistant: I'd suggest trying green curry or tom yum soup — both are classic Thai dishes ... + +=== Listing all sessions === + +Found 2 session(s): + - tenant-alice-session-1 + - tenant-bob-session-1 + +=== Resuming Alice's session === + +Alice: What were we discussing? +Assistant: We were discussing your trip to Italy and your love for Renaissance art ... + +=== Resuming Bob's session === + +Bob: What was the last dish I mentioned? +Assistant: You mentioned pad thai — it was the dish you just made! + +=== Per-session message counts === + +Alice's session: 6 messages +Bob's session: 6 messages + +=== Cleaning up === + +Cleared Alice's and Bob's sessions. 
+""" diff --git a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py b/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py deleted file mode 100644 index 7c19d114a0..0000000000 --- a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing_foundry.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. -# ruff: noqa: T201 - -"""Sample: Workflow Checkpointing with Cosmos DB and Azure AI Foundry. - -Purpose: -This sample demonstrates how to use CosmosCheckpointStorage with agents built -on Azure AI Foundry (via AzureOpenAIResponsesClient). It shows a multi-agent -workflow where checkpoint state is persisted to Cosmos DB, enabling durable -pause-and-resume across process restarts. - -What you learn: -- How to wire CosmosCheckpointStorage with AzureOpenAIResponsesClient agents -- How to combine session history with workflow checkpointing -- How to resume a workflow-as-agent from a Cosmos DB checkpoint - -Key concepts: -- AgentSession: Maintains conversation history across agent invocations -- CosmosCheckpointStorage: Persists workflow execution state in Cosmos DB -- These are complementary: sessions track conversation, checkpoints track workflow state - -Environment variables: - AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint - AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name - AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint - AZURE_COSMOS_DATABASE_NAME - Database name - AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints -Optional: - AZURE_COSMOS_KEY - Account key (if not using Azure credentials) -""" - -import asyncio -import os -from typing import Any - -from agent_framework.azure import AzureOpenAIResponsesClient -from agent_framework.orchestrations import SequentialBuilder -from azure.identity.aio import AzureCliCredential -from dotenv import load_dotenv - -from agent_framework_azure_cosmos import CosmosCheckpointStorage - -load_dotenv() 
- - -async def main() -> None: - """Run the Azure AI Foundry + Cosmos DB checkpointing sample.""" - project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT",) - deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") - cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") - cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") - cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") - cosmos_key = os.getenv("AZURE_COSMOS_KEY") - - if not project_endpoint or not deployment_name: - print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") - return - - if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: - print( - "Please set AZURE_COSMOS_ENDPOINT, AZURE_COSMOS_DATABASE_NAME, " - "and AZURE_COSMOS_CONTAINER_NAME." - ) - return - - # Use a single AzureCliCredential for both Cosmos and Foundry, - # properly closed via async context manager. - async with AzureCliCredential() as azure_credential: - cosmos_credential: Any = cosmos_key if cosmos_key else azure_credential - - async with CosmosCheckpointStorage( - endpoint=cosmos_endpoint, - credential=cosmos_credential, - database_name=cosmos_database_name, - container_name=cosmos_container_name, - ) as checkpoint_storage: - # Create Azure AI Foundry agents - client = AzureOpenAIResponsesClient( - project_endpoint=project_endpoint, - deployment_name=deployment_name, - credential=azure_credential, - ) - - assistant = client.as_agent( - name="assistant", - instructions="You are a helpful assistant. Keep responses brief.", - ) - - reviewer = client.as_agent( - name="reviewer", - instructions="You are a reviewer. 
Provide a one-sentence summary of the assistant's response.", - ) - - # Build a sequential workflow and wrap it as an agent - workflow = SequentialBuilder(participants=[assistant, reviewer]).build() - agent = workflow.as_agent(name="FoundryCheckpointedAgent") - - # --- First run: execute with Cosmos DB checkpointing --- - print("=== First Run ===\n") - - session = agent.create_session() - query = "What are the benefits of renewable energy?" - print(f"User: {query}") - - response = await agent.run(query, session=session, checkpoint_storage=checkpoint_storage) - - for msg in response.messages: - speaker = msg.author_name or msg.role - print(f"[{speaker}]: {msg.text}") - - # Show checkpoints persisted in Cosmos DB - checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) - print(f"\nCheckpoints in Cosmos DB: {len(checkpoints)}") - for i, cp in enumerate(checkpoints[:5], 1): - print(f" {i}. {cp.checkpoint_id} (iteration={cp.iteration_count})") - - # --- Second run: continue conversation with checkpoint history --- - print("\n=== Second Run (continuing conversation) ===\n") - - query2 = "Can you elaborate on the economic benefits?" 
- print(f"User: {query2}") - - response2 = await agent.run(query2, session=session, checkpoint_storage=checkpoint_storage) - - for msg in response2.messages: - speaker = msg.author_name or msg.role - print(f"[{speaker}]: {msg.text}") - - # Show total checkpoints - all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name) - print(f"\nTotal checkpoints after two runs: {len(all_checkpoints)}") - - # Get latest checkpoint - latest = await checkpoint_storage.get_latest(workflow_name=workflow.name) - if latest: - print(f"Latest checkpoint: {latest.checkpoint_id}") - print(f" iteration_count: {latest.iteration_count}") - print(f" timestamp: {latest.timestamp}") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py b/python/samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py similarity index 100% rename from python/samples/02-agents/conversations/cosmos_workflow_checkpointing.py rename to python/samples/03-workflows/checkpoint/cosmos_workflow_checkpointing.py diff --git a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py b/python/samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py similarity index 86% rename from python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py rename to python/samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py index 37fd6c986a..49c3e779f9 100644 --- a/python/packages/azure-cosmos/samples/checkpoint_storage/cosmos_checkpoint_foundry.py +++ b/python/samples/03-workflows/checkpoint/cosmos_workflow_checkpointing_foundry.py @@ -5,12 +5,12 @@ Purpose: This sample demonstrates how to use CosmosCheckpointStorage with agents built -on Azure AI Foundry (via AzureOpenAIResponsesClient). It shows a multi-agent +on Azure AI Foundry (via FoundryChatClient). 
It shows a multi-agent workflow where checkpoint state is persisted to Cosmos DB, enabling durable pause-and-resume across process restarts. What you learn: -- How to wire CosmosCheckpointStorage with AzureOpenAIResponsesClient agents +- How to wire CosmosCheckpointStorage with FoundryChatClient agents - How to combine session history with workflow checkpointing - How to resume a workflow-as-agent from a Cosmos DB checkpoint @@ -20,8 +20,8 @@ - These are complementary: sessions track conversation, checkpoints track workflow state Environment variables: - AZURE_AI_PROJECT_ENDPOINT - Azure AI Foundry project endpoint - AZURE_AI_MODEL_DEPLOYMENT_NAME - Model deployment name + FOUNDRY_PROJECT_ENDPOINT - Azure AI Foundry project endpoint + FOUNDRY_MODEL - Model deployment name AZURE_COSMOS_ENDPOINT - Cosmos DB account endpoint AZURE_COSMOS_DATABASE_NAME - Database name AZURE_COSMOS_CONTAINER_NAME - Container name for checkpoints @@ -33,27 +33,27 @@ import os from typing import Any -from agent_framework.azure import AzureOpenAIResponsesClient +from agent_framework import Agent +from agent_framework.foundry import FoundryChatClient from agent_framework.orchestrations import SequentialBuilder +from agent_framework_azure_cosmos import CosmosCheckpointStorage from azure.identity.aio import AzureCliCredential from dotenv import load_dotenv -from agent_framework_azure_cosmos import CosmosCheckpointStorage - load_dotenv() async def main() -> None: """Run the Azure AI Foundry + Cosmos DB checkpointing sample.""" - project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT",) - deployment_name = os.getenv("AZURE_AI_MODEL_DEPLOYMENT_NAME") + project_endpoint = os.getenv("FOUNDRY_PROJECT_ENDPOINT") + model = os.getenv("FOUNDRY_MODEL") cosmos_endpoint = os.getenv("AZURE_COSMOS_ENDPOINT") cosmos_database_name = os.getenv("AZURE_COSMOS_DATABASE_NAME") cosmos_container_name = os.getenv("AZURE_COSMOS_CONTAINER_NAME") cosmos_key = os.getenv("AZURE_COSMOS_KEY") - if not project_endpoint or 
not deployment_name: - print("Please set AZURE_AI_PROJECT_ENDPOINT and AZURE_AI_MODEL_DEPLOYMENT_NAME.") + if not project_endpoint or not model: + print("Please set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL.") return if not cosmos_endpoint or not cosmos_database_name or not cosmos_container_name: @@ -75,20 +75,22 @@ async def main() -> None: container_name=cosmos_container_name, ) as checkpoint_storage: # Create Azure AI Foundry agents - client = AzureOpenAIResponsesClient( + client = FoundryChatClient( project_endpoint=project_endpoint, - deployment_name=deployment_name, + model=model, credential=azure_credential, ) - assistant = client.as_agent( + assistant = Agent( name="assistant", instructions="You are a helpful assistant. Keep responses brief.", + client=client, ) - reviewer = client.as_agent( + reviewer = Agent( name="reviewer", instructions="You are a reviewer. Provide a one-sentence summary of the assistant's response.", + client=client, ) # Build a sequential workflow and wrap it as an agent From 3ed041f0a58ac8e4bf91c6b5d5978e2148b5558f Mon Sep 17 00:00:00 2001 From: Aayush Kataria Date: Wed, 8 Apr 2026 11:11:30 -0700 Subject: [PATCH 8/8] Resolving comments --- .../_checkpoint_storage.py | 12 +++++++++--- python/samples/03-workflows/README.md | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py index 08db5d51b9..4544311fd9 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_checkpoint_storage.py @@ -51,7 +51,9 @@ class CosmosCheckpointStorage: if they do not already exist. The container uses partition key ``/workflow_name``. - Example using managed identity / RBAC:: + Example using managed identity / RBAC: + + .. 
code-block:: python from azure.identity.aio import DefaultAzureCredential from agent_framework_azure_cosmos import CosmosCheckpointStorage @@ -63,7 +65,9 @@ class CosmosCheckpointStorage: container_name="checkpoints", ) - Example using account key:: + Example using account key: + + .. code-block:: python storage = CosmosCheckpointStorage( endpoint="https://my-account.documents.azure.com:443/", @@ -72,7 +76,9 @@ class CosmosCheckpointStorage: container_name="checkpoints", ) - Then use with a workflow builder:: + Then use with a workflow builder: + + .. code-block:: python workflow = WorkflowBuilder( start_executor=start, diff --git a/python/samples/03-workflows/README.md b/python/samples/03-workflows/README.md index 1fd5ab01fc..ae3292a07a 100644 --- a/python/samples/03-workflows/README.md +++ b/python/samples/03-workflows/README.md @@ -52,6 +52,8 @@ Once comfortable with these, explore the rest of the samples below. | Checkpointed Sub-Workflow | [checkpoint/sub_workflow_checkpoint.py](./checkpoint/sub_workflow_checkpoint.py) | Save and resume a sub-workflow that pauses for human approval | | Handoff + Tool Approval Resume | [orchestrations/handoff_with_tool_approval_checkpoint_resume.py](./orchestrations/handoff_with_tool_approval_checkpoint_resume.py) | Handoff workflow that captures tool-call approvals in checkpoints and resumes with human decisions | | Workflow as Agent Checkpoint | [checkpoint/workflow_as_agent_checkpoint.py](./checkpoint/workflow_as_agent_checkpoint.py) | Enable checkpointing when using workflow.as_agent() with checkpoint_storage parameter | +| Cosmos DB Checkpoint Storage | [checkpoint/cosmos_workflow_checkpointing.py](./checkpoint/cosmos_workflow_checkpointing.py) | Use `CosmosCheckpointStorage` for durable workflow checkpointing backed by Azure Cosmos DB NoSQL | +| Cosmos DB + Foundry Checkpoint | [checkpoint/cosmos_workflow_checkpointing_foundry.py](./checkpoint/cosmos_workflow_checkpointing_foundry.py) | Multi-agent workflow using 
`FoundryChatClient` with `CosmosCheckpointStorage` for durable pause/resume | ### composition