diff --git a/examples/session_stores/README.md b/examples/session_stores/README.md index 2cbaa56b..d5cff423 100644 --- a/examples/session_stores/README.md +++ b/examples/session_stores/README.md @@ -87,6 +87,18 @@ through the relevant items below. - Add a retention job (`DELETE WHERE mtime < ...`) — the table grows unbounded. +### MongoDB + +- Size the `AsyncMongoClient` connection pool for expected concurrent + sessions; don't share a pool with request-handler code that holds + connections. +- The summary sidecar is read-fold-written inside `append()`. The adapter + serializes per-session updates with an `asyncio.Lock`, but multi-process + writers against the same session would still race — pin a session to a + single writer or layer your own coordination on top. +- Implement retention via a TTL index on a BSON `Date` field (the integer + `mtime` is not TTL-eligible) or a scheduled `delete_many` — both collections grow unbounded. + --- ## S3 — `s3_session_store.py` @@ -410,3 +422,141 @@ SESSION_STORE_POSTGRES_URL=postgresql://postgres:postgres@localhost:5432/postgre Each run creates a random-suffixed table and `DROP`s it on teardown, so the target database is left clean. + +--- + +## MongoDB — `mongodb_session_store.py` + +Backed by the official [`pymongo`](https://www.mongodb.com/docs/drivers/pymongo/) +driver via its stable async API (`pymongo.AsyncMongoClient`, introduced in +pymongo 4.13). 
+ +### Installation + +```bash +pip install claude-agent-sdk 'pymongo>=4.13' +``` + +### Usage + +```python +from pymongo import AsyncMongoClient +from claude_agent_sdk import ClaudeAgentOptions, ResultMessage, query + +from mongodb_session_store import MongoDBSessionStore + +client = AsyncMongoClient("mongodb://localhost:27017") +store = MongoDBSessionStore(client=client, db_name="claude") +await store.create_schema() # idempotent createIndexes + +async for message in query( + prompt="Hello!", + options=ClaudeAgentOptions(session_store=store), +): + if isinstance(message, ResultMessage) and message.subtype == "success": + print(message.result) +``` + +### Schema + +Two collections share a single database. Entries — one document per +transcript entry, ordered by the driver-generated `_id` (`ObjectId`): + +```python +{ + "_id": ObjectId, + "project_key": str, + "session_id": str, + "subpath": str, # "" sentinel for the main transcript + "entry": dict, # transcript entry, stored verbatim + "mtime": int, # Unix epoch ms, write-time stamp +} +``` + +Summaries — one document per main session, maintained incrementally inside +`append()` via `fold_session_summary`: + +```python +{ + "_id": {"project_key": str, "session_id": str}, + "mtime": int, # epoch ms (same clock as entries) + "data": Any, # SDK-owned summary state, persisted verbatim +} +``` + +`create_schema()` creates three indexes — `(project_key, session_id, subpath, +_id)` covering `load()`/`delete()`/`list_subkeys()`, `(project_key, subpath, +mtime DESC)` covering `list_sessions()`, and `(_id.project_key, mtime DESC)` +on the summaries collection. `append()` is a single `insert_many` plus a +summary read-fold-upsert (main transcript only); `load()` is `find().sort("_id", 1)`. 
+`list_sessions_from_store()` takes a fast path when a store offers it — +**one** batch read for all summaries plus a cheap `list_sessions()` to +gap-fill stale or missing entries — instead of falling back to **N** +per-session `load()` calls (bounded at 16 concurrent) to recompute the +summary on the fly. For projects with many sessions or remote-backend +latency, that's the difference between a couple of round trips and dozens. + +The summary itself is computed by the SDK's +`fold_session_summary` (read inside `append()` and written back as one +opaque `data` blob); the adapter never interprets the contents. Note +that `fold_session_summary` lives under the SDK's `_internal` package — +treat it as a public surface for adapters specifically (the +`SessionStore` protocol references it) but keep the import in one +place so a future relocation is a single-line fix. + +### Concurrency + +Per the `SessionStore.list_session_summaries` contract, sidecar updates +inside `append()` must be serialized when calls can race for the same +session. The adapter holds a per-session, process-local `asyncio.Lock` keyed by +`(project_key, session_id)` for the duration of the read-fold-write. + +### Retention + +This adapter never deletes documents on its own. Add a TTL index on a BSON +`Date` field (TTL indexes never expire integer values such as the epoch-ms +`mtime`, so store a `Date` alongside it) or schedule a `delete_many({"mtime": {"$lt": +cutoff}})` to expire transcripts according to your compliance requirements. + +`delete()` is implemented (cascades to subpath documents and the summary +sidecar) but is only called when you invoke `delete_session_via_store()` +from the SDK. + +Local-disk transcripts under `CLAUDE_CONFIG_DIR` are swept independently by +the CLI's `cleanupPeriodDays` setting. + +### Resume from MongoDB + +```python +async for message in query( + prompt="Continue where we left off", + options=ClaudeAgentOptions( + session_store=store, + resume="previous-session-id", + ), +): + ... 
+``` + +### Live MongoDB end-to-end + +There is no in-process MongoDB mock that faithfully exercises aggregation +and `distinct`, so the MongoDB tests run **live-only**. They skip +automatically unless `SESSION_STORE_MONGODB_URL` is set: + +```bash +docker run -d -p 27017:27017 mongo:latest + +SESSION_STORE_MONGODB_URL=mongodb://localhost:27017 \ + pytest tests/test_example_mongodb_session_store.py -v +``` + +Each run uses a random database name and drops it on teardown. + +This mirrors the `MongoDBSessionStore` reference adapter in the TypeScript +SDK's [`examples/session-stores/mongodb/`](https://github.com/anthropics/claude-agent-sdk-typescript/tree/main/examples/session-stores/mongodb). diff --git a/examples/session_stores/mongodb_session_store.py b/examples/session_stores/mongodb_session_store.py new file mode 100644 index 00000000..babb8685 --- /dev/null +++ b/examples/session_stores/mongodb_session_store.py @@ -0,0 +1,361 @@ +"""MongoDB-backed :class:`~claude_agent_sdk.SessionStore` reference adapter. + +This is a **reference implementation** demonstrating that the +:class:`~claude_agent_sdk.SessionStore` protocol generalizes to a document +store. It is not shipped as part of the SDK; copy it into your project and +adapt as needed (add migrations, sharding, retention sweeps, etc.). This +mirrors the ``MongoDBSessionStore`` reference implementation from the +TypeScript SDK. + +Requires ``pymongo>=4.13`` (the stable async API). Install with:: + + pip install 'pymongo>=4.13' + +Usage:: + + from pymongo import AsyncMongoClient + from claude_agent_sdk import ClaudeAgentOptions, query + + from mongodb_session_store import MongoDBSessionStore + + client = AsyncMongoClient("mongodb://localhost:27017") + store = MongoDBSessionStore(client=client, db_name="claude") + await store.create_schema() # one-time, idempotent + + async for message in query( + prompt="Hello!", + options=ClaudeAgentOptions(session_store=store), + ): + ... 
# messages are mirrored to MongoDB as they stream + +Schema +------ +Two collections share a single database: + +``claude_session_entries`` — one document per JSONL entry:: + + { + _id: ObjectId, # ordering key (driver-generated) + project_key: str, + session_id: str, + subpath: str, # "" sentinel for main transcript + entry: dict, # transcript entry, stored verbatim + mtime: int, # Unix epoch ms, write-time stamp + } + +``claude_session_summaries`` — one document per main session, maintained +incrementally inside :meth:`MongoDBSessionStore.append` via +:func:`~claude_agent_sdk._internal.session_summary.fold_session_summary`:: + + { + _id: {project_key: str, session_id: str}, + mtime: int, # Unix epoch ms (same clock as entries) + data: Any, # SDK-owned summary state, persisted verbatim + } + +The empty string is the ``subpath`` sentinel for the main transcript so the +``(project_key, session_id, subpath)`` triple is total (mirrors the Postgres +adapter's convention). + +Concurrency +----------- +Per the :meth:`SessionStore.list_session_summaries` contract, stores +maintaining sidecars inside ``append()`` must serialize the read-fold-write +when ``append()`` calls can race for the same session. This adapter holds a +per-session ``asyncio.Lock`` keyed by ``(project_key, session_id)`` for the +duration of the summary update. The SDK's own ``TranscriptMirrorBatcher`` +already sequences appends per session within one process, but a user could +share one store instance across multiple concurrent batchers — the lock keeps +the fold deterministic in that case. + +Retention +--------- +This adapter never deletes documents on its own. Add a TTL index on a BSON +``Date`` field (the integer ``mtime`` is not TTL-eligible) or a scheduled +``delete_many({"mtime": {"$lt": cutoff}})`` to expire transcripts according to +your compliance requirements. Local-disk transcripts under ``CLAUDE_CONFIG_DIR`` +are swept independently by the CLI's ``cleanupPeriodDays`` setting. 
+""" + +from __future__ import annotations + +import asyncio +import re +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from claude_agent_sdk import ( + SessionKey, + SessionListSubkeysKey, + SessionStore, + SessionStoreEntry, + SessionStoreListEntry, + SessionSummaryEntry, +) + +# fold_session_summary lives under _internal but is treated as a public +# surface for SessionStore adapters (the protocol's docstrings reference it). +# Keep the import in one place so a future relocation is a single-line fix. +from claude_agent_sdk._internal.session_summary import fold_session_summary + +if TYPE_CHECKING: + from pymongo import AsyncMongoClient + from pymongo.asynchronous.collection import AsyncCollection + from pymongo.asynchronous.database import AsyncDatabase + +#: Conservative collection-name guard. Mongo allows ``.`` for namespacing but +#: ``$`` and null bytes are invalid; reject anything that isn't a plain +#: ``[A-Za-z_][A-Za-z0-9_.]*`` to head off injection-like footguns. +_IDENT_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_.]*$") + +#: Sentinel used in entry documents to mark the main transcript. The SDK +#: never emits an empty subpath; treating ``key.get("subpath") or ""`` as the +#: sentinel keeps the Mongo query and Postgres adapter aligned. +_MAIN: str = "" + + +@dataclass +class MongoDBSessionStoreOptions: + """Configuration for :class:`MongoDBSessionStore`.""" + + client: AsyncMongoClient[dict[str, Any]] + """Pre-configured ``pymongo.AsyncMongoClient``. Caller controls URI, + auth, TLS, pool sizing, server selection, etc.""" + + db_name: str | None = None + """Database name. Falls back to the client's default database (i.e. the + one named in the connection URI) when ``None``.""" + + entries_collection: str = "claude_session_entries" + """Collection name for transcript entries. 
Must match + ``[A-Za-z_][A-Za-z0-9_.]*``.""" + + summaries_collection: str = "claude_session_summaries" + """Collection name for the per-session summary sidecar. Must match + ``[A-Za-z_][A-Za-z0-9_.]*``.""" + + +class MongoDBSessionStore(SessionStore): + """MongoDB-backed :class:`~claude_agent_sdk.SessionStore`. + + One document per transcript entry; ordering via the driver-generated + ``_id`` (``ObjectId``). ``append()`` is a single ``insert_many`` (plus a + summary upsert for the main transcript); ``load()`` is ``find().sort("_id", 1)``. + + Args: + client: Pre-configured ``pymongo.AsyncMongoClient``. + db_name: Database name (default: the client's default DB). + entries_collection: Collection for entry documents + (default ``"claude_session_entries"``). + summaries_collection: Collection for summary sidecars + (default ``"claude_session_summaries"``). + options: Alternative to positional args; takes precedence if given. + """ + + def __init__( + self, + client: AsyncMongoClient[dict[str, Any]] | None = None, + db_name: str | None = None, + entries_collection: str = "claude_session_entries", + summaries_collection: str = "claude_session_summaries", + *, + options: MongoDBSessionStoreOptions | None = None, + ) -> None: + if options is not None: + client = options.client + db_name = options.db_name + entries_collection = options.entries_collection + summaries_collection = options.summaries_collection + if client is None: + raise ValueError("MongoDBSessionStore requires 'client'") + for label, name in ( + ("entries_collection", entries_collection), + ("summaries_collection", summaries_collection), + ): + if not _IDENT_RE.fullmatch(name): # match() + "$" would admit "name\n" + raise ValueError(f"{label} {name!r} must match [A-Za-z_][A-Za-z0-9_.]*") + + self._db: AsyncDatabase[dict[str, Any]] = ( + client[db_name] if db_name is not None else client.get_default_database() + ) + self._entries: AsyncCollection[dict[str, Any]] = self._db[entries_collection] + self._summaries: AsyncCollection[dict[str, Any]] = self._db[ + summaries_collection + ] + # 
Per-session locks for the read-fold-write summary update. Keys are + # (project_key, session_id); locks are created lazily and never + # garbage-collected — this is reference code, not a long-running + # service. + self._summary_locks: dict[tuple[str, str], asyncio.Lock] = {} + + def _summary_lock(self, key: SessionKey) -> asyncio.Lock: + slot = (key["project_key"], key["session_id"]) + lock = self._summary_locks.get(slot) + if lock is None: + lock = asyncio.Lock() + self._summary_locks[slot] = lock + return lock + + # ------------------------------------------------------------------ + # Schema + # ------------------------------------------------------------------ + + async def create_schema(self) -> None: + """Create the indexes if absent. Idempotent. + + Call once at startup (or run the equivalent migration out-of-band). + Each index is independently named so re-running the call is a no-op + in the steady state. + """ + await self._entries.create_index( + [("project_key", 1), ("session_id", 1), ("subpath", 1), ("_id", 1)], + name="key_idx", + ) + await self._entries.create_index( + [("project_key", 1), ("subpath", 1), ("mtime", -1)], + name="sessions_idx", + ) + await self._summaries.create_index( + [("_id.project_key", 1), ("mtime", -1)], + name="summaries_idx", + ) + + # ------------------------------------------------------------------ + # SessionStore protocol + # ------------------------------------------------------------------ + + async def append(self, key: SessionKey, entries: list[SessionStoreEntry]) -> None: + if not entries: + return + subpath = key.get("subpath") or _MAIN + now = int(time.time() * 1000) + docs: list[dict[str, Any]] = [ + { + "project_key": key["project_key"], + "session_id": key["session_id"], + "subpath": subpath, + "entry": dict(entry), + "mtime": now, + } + for entry in entries + ] + # ordered=True preserves intra-batch order; the driver-generated + # ObjectId is monotonic per writer so inter-batch order is preserved + # too 
without an explicit sequence column. + await self._entries.insert_many(docs, ordered=True) + + # Subagent transcripts must NOT contribute to the main session's + # summary — guard before the fold (per fold_session_summary docs). + if subpath != _MAIN: + return + + compound_id = { + "project_key": key["project_key"], + "session_id": key["session_id"], + } + async with self._summary_lock(key): + prev_doc = await self._summaries.find_one({"_id": compound_id}) + prev: SessionSummaryEntry | None = ( + { + "session_id": prev_doc["_id"]["session_id"], + "mtime": int(prev_doc["mtime"]), + "data": prev_doc["data"], + } + if prev_doc is not None + else None + ) + folded = fold_session_summary(prev, key, entries) + new_doc = { + "_id": compound_id, + "mtime": now, + "data": folded["data"], + } + await self._summaries.replace_one( + {"_id": compound_id}, new_doc, upsert=True + ) + + async def load(self, key: SessionKey) -> list[SessionStoreEntry] | None: + cursor = self._entries.find( + { + "project_key": key["project_key"], + "session_id": key["session_id"], + "subpath": key.get("subpath") or _MAIN, + } + ).sort("_id", 1) + docs = await cursor.to_list(length=None) + if not docs: + return None + return [d["entry"] for d in docs] + + async def list_sessions(self, project_key: str) -> list[SessionStoreListEntry]: + # ``aggregate()`` is itself awaitable in pymongo's async API (returns + # a cursor); ``find()`` is not (returns the cursor synchronously). 
+ cursor = await self._entries.aggregate( + [ + {"$match": {"project_key": project_key, "subpath": _MAIN}}, + { + "$group": { + "_id": "$session_id", + "mtime": {"$max": "$mtime"}, + } + }, + ] + ) + rows = await cursor.to_list(length=None) + return [{"session_id": str(r["_id"]), "mtime": int(r["mtime"])} for r in rows] + + async def list_session_summaries( + self, project_key: str + ) -> list[SessionSummaryEntry]: + cursor = self._summaries.find({"_id.project_key": project_key}) + docs = await cursor.to_list(length=None) + return [ + { + "session_id": d["_id"]["session_id"], + "mtime": int(d["mtime"]), + "data": d["data"], + } + for d in docs + ] + + async def delete(self, key: SessionKey) -> None: + subpath = key.get("subpath") + if subpath: + # Targeted: remove only this subpath's entries; do NOT touch the + # summary sidecar (which represents the main transcript). + await self._entries.delete_many( + { + "project_key": key["project_key"], + "session_id": key["session_id"], + "subpath": subpath, + } + ) + return + # Cascade: main + every subpath under (project_key, session_id), + # plus the summary sidecar. 
+ await self._entries.delete_many( + { + "project_key": key["project_key"], + "session_id": key["session_id"], + } + ) + await self._summaries.delete_one( + { + "_id": { + "project_key": key["project_key"], + "session_id": key["session_id"], + } + } + ) + + async def list_subkeys(self, key: SessionListSubkeysKey) -> list[str]: + result = await self._entries.distinct( + "subpath", + { + "project_key": key["project_key"], + "session_id": key["session_id"], + "subpath": {"$ne": _MAIN}, + }, + ) + return list(result) diff --git a/pyproject.toml b/pyproject.toml index 6ea41b0c..52d901af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ examples = [ "redis>=4.2.0", "fakeredis>=2.20.0", "asyncpg>=0.27.0", + "pymongo>=4.13", ] [project.urls] diff --git a/tests/test_example_mongodb_session_store.py b/tests/test_example_mongodb_session_store.py new file mode 100644 index 00000000..c669ef77 --- /dev/null +++ b/tests/test_example_mongodb_session_store.py @@ -0,0 +1,359 @@ +"""Live-MongoDB tests for the example ``MongoDBSessionStore`` adapter. + +There is no in-process MongoDB mock that faithfully exercises aggregation +and ``distinct``, so this module is **live-only**: it skips unless +``SESSION_STORE_MONGODB_URL`` is set. Each run uses a random database name +and drops it on teardown. + +Run locally:: + + docker run -d -p 27017:27017 mongo:latest + SESSION_STORE_MONGODB_URL=mongodb://localhost:27017 \\ + pytest tests/test_example_mongodb_session_store.py -v +""" + +from __future__ import annotations + +import importlib.util +import itertools +import json +import os +import sys +import uuid +from collections.abc import AsyncIterator +from pathlib import Path + +import pytest +import pytest_asyncio + +# The example adapter and these tests are optional — skip the whole module +# if the [examples] dependency group isn't installed. 
+pymongo = pytest.importorskip( + "pymongo", reason="pymongo not installed (pip install .[examples])" +) + +MONGODB_URL = os.environ.get("SESSION_STORE_MONGODB_URL") +if not MONGODB_URL: + pytest.skip( + "live MongoDB e2e: set SESSION_STORE_MONGODB_URL " + "(e.g. mongodb://localhost:27017)", + allow_module_level=True, + ) + +from pymongo import AsyncMongoClient # noqa: E402 + +from claude_agent_sdk import ( # noqa: E402 + ClaudeAgentOptions, + SessionStore, + project_key_for_directory, +) +from claude_agent_sdk._internal.session_resume import ( # noqa: E402 + materialize_resume_session, +) +from claude_agent_sdk._internal.transcript_mirror_batcher import ( # noqa: E402 + TranscriptMirrorBatcher, +) +from claude_agent_sdk.testing import run_session_store_conformance # noqa: E402 + +# --------------------------------------------------------------------------- +# Import the example adapter without polluting sys.path globally. +# --------------------------------------------------------------------------- + +_EXAMPLE_PATH = ( + Path(__file__).parent.parent + / "examples" + / "session_stores" + / "mongodb_session_store.py" +) +_spec = importlib.util.spec_from_file_location( + "_mongodb_session_store_example", _EXAMPLE_PATH +) +assert _spec is not None and _spec.loader is not None +_module = importlib.util.module_from_spec(_spec) +sys.modules[_spec.name] = _module +_spec.loader.exec_module(_module) +MongoDBSessionStore = _module.MongoDBSessionStore +MongoDBSessionStoreOptions = _module.MongoDBSessionStoreOptions + + +SESSION_ID = "550e8400-e29b-41d4-a716-446655440000" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def client() -> AsyncIterator[AsyncMongoClient]: + c: AsyncMongoClient = AsyncMongoClient(MONGODB_URL) + try: + yield c + finally: + await c.close() + + +@pytest_asyncio.fixture +async def 
db_name(client: AsyncMongoClient) -> AsyncIterator[str]: + name = f"claude_test_{uuid.uuid4().hex[:8]}" + try: + yield name + finally: + await client.drop_database(name) + + +@pytest_asyncio.fixture +async def store(client: AsyncMongoClient, db_name: str) -> SessionStore: + s = MongoDBSessionStore( + options=MongoDBSessionStoreOptions(client=client, db_name=db_name) + ) + await s.create_schema() + return s + + +# --------------------------------------------------------------------------- +# Conformance harness +# --------------------------------------------------------------------------- + + +class TestConformance: + @pytest.mark.asyncio + async def test_conformance(self, client: AsyncMongoClient, db_name: str) -> None: + # The harness calls make_store() once per contract for isolation. + # Give each call its own collection pair so contracts don't see each + # other's documents; cleanup happens via the db_name teardown. + counter = itertools.count() + + async def make_store() -> SessionStore: + n = next(counter) + s = MongoDBSessionStore( + client=client, + db_name=db_name, + entries_collection=f"entries_{n}", + summaries_collection=f"summaries_{n}", + ) + await s.create_schema() + return s + + await run_session_store_conformance(make_store) + + def test_store_implements_required_methods(self, store: SessionStore) -> None: + """SessionStore is not @runtime_checkable; probe via _store_implements().""" + from claude_agent_sdk._internal.session_store_validation import ( + _store_implements, + ) + + assert _store_implements(store, "append") + assert _store_implements(store, "load") + + def test_rejects_unsafe_collection_name(self, client: AsyncMongoClient) -> None: + with pytest.raises(ValueError, match="must match"): + MongoDBSessionStore( + client=client, + entries_collection="bad; drop", + ) + with pytest.raises(ValueError, match="must match"): + MongoDBSessionStore( + client=client, + summaries_collection="bad$col", + ) + + +# 
--------------------------------------------------------------------------- +# Adapter-specific invariants the conformance suite cannot probe. +# --------------------------------------------------------------------------- + + +class TestAdapterSpecific: + @pytest.mark.asyncio + async def test_create_schema_is_idempotent( + self, client: AsyncMongoClient, db_name: str + ) -> None: + """Calling create_schema() twice must not raise (matches Postgres).""" + s = MongoDBSessionStore( + client=client, + db_name=db_name, + entries_collection="schema_idem_entries", + summaries_collection="schema_idem_summaries", + ) + await s.create_schema() + await s.create_schema() + await s.append({"project_key": "p", "session_id": "s"}, [{"type": "a"}]) + loaded = await s.load({"project_key": "p", "session_id": "s"}) + assert loaded == [{"type": "a"}] + + @pytest.mark.asyncio + async def test_options_kwarg_path( + self, client: AsyncMongoClient, db_name: str + ) -> None: + """The dataclass options= path must be equivalent to positional args.""" + s = MongoDBSessionStore( + options=MongoDBSessionStoreOptions( + client=client, + db_name=db_name, + entries_collection="opts_entries", + summaries_collection="opts_summaries", + ) + ) + await s.create_schema() + await s.append({"project_key": "p", "session_id": "s"}, [{"type": "a"}]) + assert await s.load({"project_key": "p", "session_id": "s"}) == [{"type": "a"}] + + @pytest.mark.asyncio + async def test_subpath_delete_preserves_summary( + self, client: AsyncMongoClient, db_name: str + ) -> None: + """Targeted subpath delete must NOT touch the main session's summary + sidecar. 
Only main delete (no subpath) cascades to the summary.""" + s = MongoDBSessionStore( + client=client, + db_name=db_name, + entries_collection="sub_del_entries", + summaries_collection="sub_del_summaries", + ) + await s.create_schema() + key = {"project_key": "p", "session_id": "s"} + await s.append(key, [{"type": "user", "customTitle": "title"}]) + await s.append({**key, "subpath": "subagents/agent-1"}, [{"type": "user"}]) + # Sidecar exists after the main append. + before = await s.list_session_summaries("p") + assert len(before) == 1 + # Subpath delete should leave main entries AND the sidecar intact. + await s.delete({**key, "subpath": "subagents/agent-1"}) + after = await s.list_session_summaries("p") + assert len(after) == 1 + assert after[0]["data"] == before[0]["data"] + # And then a main delete actually drops the sidecar. + await s.delete(key) + assert await s.list_session_summaries("p") == [] + + @pytest.mark.asyncio + async def test_concurrent_appends_serialize_summary_fold( + self, client: AsyncMongoClient, db_name: str + ) -> None: + """The per-session asyncio.Lock must serialize the read-fold-write so + each fold sees the previous fold's output as ``prev``. + + Without the lock, two appends carrying *different* fields (one + setting ``customTitle``, the other setting ``gitBranch``) can each + read ``prev=None``, fold against an empty summary, and write a + doc that omits the other's field. The last writer wins entirely + and one field is clobbered. With the lock, the second fold sees + the first's output and merges into it, so both fields survive. + + Repeating across many trials makes a missing lock almost certain + to produce at least one clobbered run. 
+ """ + import asyncio + + s = MongoDBSessionStore( + client=client, + db_name=db_name, + entries_collection="conc_entries", + summaries_collection="conc_summaries", + ) + await s.create_schema() + + for trial in range(30): + key = {"project_key": "p", "session_id": f"trial-{trial}"} + + # Default-arg binds `key` at definition time so the closures + # don't capture the mutating loop variable (ruff B023). + async def with_title(k: dict[str, str] = key) -> None: + await s.append( + k, + [{"type": "user", "uuid": "t", "customTitle": "TITLE"}], + ) + + async def with_branch(k: dict[str, str] = key) -> None: + await s.append(k, [{"type": "user", "uuid": "b", "gitBranch": "main"}]) + + await asyncio.gather(with_title(), with_branch()) + + summaries = [ + s2 + for s2 in await s.list_session_summaries("p") + if s2["session_id"] == f"trial-{trial}" + ] + assert len(summaries) == 1 + data = summaries[0]["data"] + # With the lock, both fields must be present after any + # interleaving. A missing field => fold raced => regression. + assert data.get("custom_title") == "TITLE", ( + f"trial {trial}: custom_title clobbered (lock removed?) — data={data}" + ) + assert data.get("git_branch") == "main", ( + f"trial {trial}: git_branch clobbered (lock removed?) — data={data}" + ) + + +# --------------------------------------------------------------------------- +# Full round-trip: TranscriptMirrorBatcher → MongoDB → materialize_resume_session +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + @pytest.mark.asyncio + async def test_mirror_then_resume( + self, + store: SessionStore, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + # Isolate ~ so auth-file copying doesn't touch the real config. 
+ home = tmp_path / "home" + home.mkdir() + monkeypatch.setattr(Path, "home", staticmethod(lambda: home)) + monkeypatch.delenv("CLAUDE_CONFIG_DIR", raising=False) + + cwd = tmp_path / "project" + cwd.mkdir() + project_key = project_key_for_directory(cwd) + + errors: list[tuple] = [] + + async def on_error(key, msg) -> None: + errors.append((key, msg)) + + projects_dir = str(tmp_path / "config" / "projects") + batcher = TranscriptMirrorBatcher( + store=store, projects_dir=projects_dir, on_error=on_error + ) + + main_path = f"{projects_dir}/{project_key}/{SESSION_ID}.jsonl" + sub_path = f"{projects_dir}/{project_key}/{SESSION_ID}/subagents/agent-1.jsonl" + main_entries = [ + { + "type": "user", + "uuid": "u1", + "message": {"role": "user", "content": "hi"}, + }, + {"type": "assistant", "uuid": "a1", "message": {"role": "assistant"}}, + ] + sub_entries = [{"type": "user", "uuid": "su1", "isSidechain": True}] + + batcher.enqueue(main_path, main_entries) + batcher.enqueue(sub_path, sub_entries) + await batcher.flush() + assert errors == [] + + opts = ClaudeAgentOptions(cwd=cwd, session_store=store, resume=SESSION_ID) + result = await materialize_resume_session(opts) + assert result is not None + try: + assert result.resume_session_id == SESSION_ID + jsonl = ( + result.config_dir / "projects" / project_key / f"{SESSION_ID}.jsonl" + ).read_text() + assert [json.loads(line) for line in jsonl.splitlines()] == main_entries + sub_jsonl = ( + result.config_dir + / "projects" + / project_key + / SESSION_ID + / "subagents" + / "agent-1.jsonl" + ).read_text() + assert [json.loads(line) for line in sub_jsonl.splitlines()] == sub_entries + finally: + await result.cleanup()