diff --git a/src/aiperf/workers/worker.py b/src/aiperf/workers/worker.py index 695fd0045b..71aa8fcece 100644 --- a/src/aiperf/workers/worker.py +++ b/src/aiperf/workers/worker.py @@ -17,6 +17,7 @@ CacheBustTarget, CommAddress, CommandType, + ConversationBranchMode, MemoryMapFormat, MessageType, ) @@ -101,6 +102,27 @@ def _apply_cache_bust_to_system_message( return system_message +def _content_has_marker_at_edge( + content: object, marker: str, *, is_prefix: bool +) -> bool: + """Whether ``content`` already carries ``marker`` at the prefix/suffix edge. + + The injection helpers run once per credit and several paths mutate a turn + object shared across the session's turns (delta-mode ``turn_list[0]``, and + the unconditional every-credit first-user mark used for seeded resumes), so + re-injecting the constant per-session marker must not stack it. This check + is exact (the per-session marker is constant; a fresh recycled play sees + pristine content and injects its own marker). Handles plain-string and + OpenAI multimodal list-of-parts content. + """ + if isinstance(content, str): + return content.startswith(marker) if is_prefix else content.endswith(marker) + if isinstance(content, list) and content: + marker_part = {"type": "text", "text": marker.strip()} + return content[0 if is_prefix else -1] == marker_part + return False + + def _inject_marker_into_raw_messages( raw_messages: list[dict], marker: str, *, is_prefix: bool ) -> None: @@ -109,7 +131,8 @@ def _inject_marker_into_raw_messages( No-op when raw_messages is empty or the first message is not a system role. For multimodal content (``content`` is a list of parts), the marker is inserted as a new ``{"type": "text", "text": marker}`` part at the start - (prefix) or end (suffix) of the parts list. + (prefix) or end (suffix) of the parts list. Idempotent via + :func:`_content_has_marker_at_edge`. """ if not raw_messages or not marker: return @@ -117,6 +140,8 @@ def _inject_marker_into_raw_messages( if not isinstance(first, dict) or first.get("role") != "system": return content = first.get("content", "") + if _content_has_marker_at_edge(content, marker, is_prefix=is_prefix): + return if isinstance(content, str): raw_messages[0] = { **first, @@ -142,13 +167,18 @@ def _inject_marker_into_first_user_turn( No-op when raw_messages is empty. For multimodal content (``content`` is a list of parts), the marker is inserted as a new ``{"type": "text", "text": marker}`` part at the start (prefix) or end - (suffix) of the parts list. + (suffix) of the parts list. Idempotent via + :func:`_content_has_marker_at_edge` — FIRST_TURN_* injection runs every + credit (to mark seeded turn 0 on mid-trajectory resumes), so repeated calls + on the same shared turn must not stack the marker. """ if not raw_messages or not marker: return for idx, msg in enumerate(raw_messages): if isinstance(msg, dict) and msg.get("role") == "user": content = msg.get("content", "") + if _content_has_marker_at_edge(content, marker, is_prefix=is_prefix): + return if isinstance(content, str): raw_messages[idx] = { **msg, @@ -231,6 +261,8 @@ def _inject_marker_into_first_user_text( first.contents = [marker.strip()] return existing = first.contents[0] + if _content_has_marker_at_edge(existing, marker, is_prefix=is_prefix): + return first.contents[0] = (marker + existing) if is_prefix else (existing + marker) @@ -272,17 +304,16 @@ def _apply_cache_bust( ``turn_list`` where the system role lives in ``turn_list[0]`` and later deltas start with the prior assistant response). - SYSTEM_* fallback: when ``target`` is ``SYSTEM_PREFIX`` / ``SYSTEM_SUFFIX`` - and there is no system message anywhere (neither a Conversation-level - ``system_message`` nor a leading ``role=="system"`` entry in any turn's - ``raw_messages``), the marker is routed to the first user turn with the - same prefix/suffix orientation — i.e. SYSTEM_PREFIX falls back to a - first-user-turn prefix, SYSTEM_SUFFIX falls back to a first-user-turn - suffix. Without a system prompt the first user message is the prefix of - the entire wire payload, so this produces the same physical token-0 - divergence without fabricating a system role. The fallback is gated on - ``credit.turn_index == 0`` (matches FIRST_TURN_* semantics: marker only - affects the first turn's KV cache; later turns inherit). + Injection targets the *effective wire prefix* (see + :func:`_effective_prefix_turns`): the slice of ``turn_list`` that + ``build_messages`` actually emits. A ``reset_context`` turn makes + ``build_messages`` discard every prior turn, so the effective prefix begins + at the last such turn — which may sit mid-history (seeded on a resume, + never dispatched as the current turn), not just at ``turn_list[-1]``. + Marking the discarded turn 0 instead would leave the real prefix unmarked + and let recycled plays warm the server's cache on identical post-reset + bytes. ``SYSTEM_*`` targets are handled in + :func:`_apply_system_target_cache_bust`. """ marker = credit.cache_bust_marker target = credit.cache_bust_target @@ -290,38 +321,96 @@ def _apply_cache_bust( if not marker or target == CacheBustTarget.NONE: return system_message + # FORK children share the parent's KV cache by design: they seed turn_list + # from the parent (the SAME Turn objects) and must send the parent's exact + # prefix to hit its cache. The parent already injected its marker into those + # shared turns, so the child inherits it for free — re-busting here would + # diverge the child's prefix from the parent's (cache miss) AND mutate the + # parent's shared, read-only Turn objects (stacking markers). So cache-bust + # is a no-op for FORK children. SPAWN children start fresh (no shared turns) + # and root sessions own their prefix, so both are busted normally. + if ( + session.parent_correlation_id is not None + and session.branch_mode == ConversationBranchMode.FORK + ): + return system_message + is_prefix = target in ( CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.FIRST_TURN_PREFIX, ) + prefix_turns = _effective_prefix_turns(session) if target in (CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.SYSTEM_SUFFIX): - # Three sub-paths with intentionally different semantics: - # 1. Conversation-level system_message present: marker injected - # every turn (string mutation re-applied per credit). - # 2. raw_messages first dict has role=="system": marker injected - # every turn (raw mutation re-applied per credit). Under deltas - # that dict lives in turn_list[0]; under message-array it lives - # in turn_list[-1] (same single turn). - # 3. No system anywhere -> first-user-turn fallback: marker injected - # ONLY on turn_index == 0. Subsequent turns inherit via the - # inference server's prefix-cache hit, matching FIRST_TURN_* - # semantics. Re-injecting on every turn would drift token-0 on - # every credit and fragment the cache key. - if system_message is not None: - return _apply_cache_bust_to_system_message(system_message, marker, target) - raw_system = _find_first_system_message(session.turn_list) - if raw_system is not None: - _inject_marker_into_raw_messages(raw_system, marker, is_prefix=is_prefix) - elif credit.turn_index == 0: - _inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix) - return system_message + return _apply_system_target_cache_bust( + prefix_turns, + system_message=system_message, + marker=marker, + target=target, + is_prefix=is_prefix, + ) - if credit.turn_index == 0: - _inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix) + # Mark the effective prefix's opening user turn every credit (idempotent). + # Unconditional rather than turn_index==0-gated so a seeded mid-trajectory + # resume (turn_list back-filled with turns 0..k_i at a credit whose + # turn_index > 0) still marks the true wire prefix. + _inject_marker_at_first_user(prefix_turns, marker, is_prefix=is_prefix) return system_message +def _apply_system_target_cache_bust( + prefix_turns: list[Turn], + *, + system_message: str | None, + marker: str, + target: CacheBustTarget, + is_prefix: bool, +) -> str | None: + """Inject a ``SYSTEM_PREFIX`` / ``SYSTEM_SUFFIX`` marker for one credit. + + ``prefix_turns`` is the effective wire prefix slice (see + :func:`_effective_prefix_turns`). Three sub-paths: + 1. Conversation-level ``system_message`` present: marker applied every + turn (string mutation re-applied per credit). Unaffected by + ``reset_context`` — the ``system_message`` rides on ``RequestInfo`` and + is re-emitted every turn independent of ``build_messages``' reset. + 2. ``raw_messages`` first dict has ``role=="system"``: marker injected + into the first system message of the prefix slice. + 3. No system in the slice -> first-user-turn fallback: marker injected + into the first user turn every credit (idempotent), matching + ``FIRST_TURN_*`` semantics so a seeded mid-trajectory resume still + marks the prefix. + + Returns the (possibly modified) ``system_message``. + """ + if system_message is not None: + return _apply_cache_bust_to_system_message(system_message, marker, target) + raw_system = _find_first_system_message(prefix_turns) + if raw_system is not None: + _inject_marker_into_raw_messages(raw_system, marker, is_prefix=is_prefix) + else: + _inject_marker_at_first_user(prefix_turns, marker, is_prefix=is_prefix) + return system_message + + +def _effective_prefix_turns(session: UserSession) -> list[Turn]: + """The ``turn_list`` slice that forms the wire prefix for cache-bust. + + ``base_endpoint.build_messages`` restarts the message array at every + ``reset_context`` turn that carries ``raw_messages`` (discarding everything + before it), so the effective prefix begins at the *last* such turn in + ``turn_list`` — not turn 0, and not merely ``turn_list[-1]``: a reset can + sit mid-history (e.g. seeded into a mid-trajectory resume, where it is never + dispatched as the current turn). Returns the slice from that turn to the + end, or the whole ``turn_list`` when there is no reset. + """ + turns = session.turn_list + for i in range(len(turns) - 1, -1, -1): + if turns[i].reset_context and turns[i].raw_messages: + return turns[i:] + return turns + + class Worker(BaseComponentService, ProcessHealthMixin): """Worker processes credits from the TimingManager and makes API calls to inference servers. @@ -845,8 +934,7 @@ def _maybe_warn_cache_bust_silent_drop( credit: Credit, ) -> None: """Emit a one-shot warning if cache-bust was requested but had nowhere - to land on this credit (e.g. SYSTEM_* on turn>0 with no system anywhere, - or empty session.turn_list). + to land on this credit (an empty ``session.turn_list``). Rate-limited to once per worker via ``self._cache_bust_warning_shown`` — the misconfiguration is identical for every credit, so a single @@ -864,30 +952,6 @@ def _maybe_warn_cache_bust_silent_drop( f"cache-bust target={target.value} requested but session.turn_list " f"is empty — marker NOT injected (further occurrences suppressed)." ) - return - # SYSTEM_* on turn>0 with no system anywhere: the fallback is gated on - # turn_index==0 by design (see _apply_cache_bust comments), so the - # marker is intentionally NOT re-applied. Surface this once so users - # configuring cache-bust against a synthetic / no-system trace see why - # token-0 didn't drift. - if target in (CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.SYSTEM_SUFFIX): - if session.conversation.system_message is not None: - return - last_turn = session.turn_list[-1] - raw = last_turn.raw_messages - has_raw_system = bool( - raw and isinstance(raw[0], dict) and raw[0].get("role") == "system" - ) - if not has_raw_system and credit.turn_index > 0: - self._cache_bust_warning_shown = True - self.warning( - f"cache-bust target={target.value} requested but trace has no " - f"system message (neither Conversation.system_message nor " - f"raw_messages[0].role=='system'); fallback to first-user-turn " - f"only fires on turn_index==0, so subsequent turns inherit the " - f"already-prefixed prompt. This is intentional (matches " - f"FIRST_TURN_* semantics) — further occurrences suppressed." - ) async def _execute_request( self, diff --git a/tests/component_integration/test_agentic_replay_cache_bust.py b/tests/component_integration/test_agentic_replay_cache_bust.py index f9e2caafc6..f36bcfed8e 100644 --- a/tests/component_integration/test_agentic_replay_cache_bust.py +++ b/tests/component_integration/test_agentic_replay_cache_bust.py @@ -291,14 +291,12 @@ def test_agentic_replay_cache_bust_marker_in_wire_payload( across all turns of a session, distinct across sessions, and absent from the trace turn bodies. - Note on ``FIRST_TURN_*`` semantics (spec §4.5): the worker only injects - the marker at ``credit.turn_index == 0``. Agentic_replay trajectories - that resume at ``k_i > 0`` therefore never see a FIRST_TURN_* marker — - only sessions that begin at turn 0 (recycled spawns and k_i=0 - trajectories) carry one. We restrict the per-session continuity / - cross-session distinctness assertions to *marked* sessions for - FIRST_TURN_* and require at least one such marked session to exist. - SYSTEM_* applies on every turn, so marker coverage is universal. + Marker coverage is universal for every target: FIRST_TURN_* injects into + the effective wire prefix's opening user turn on every credit (including + mid-trajectory resumes at ``k_i > 0``, whose seeded turn 0 is the real + prefix), and SYSTEM_* applies every turn. So every profiling session must + carry exactly one marker — a regression of the seeded-resume fix would show + up here as an unmarked ``k_i > 0`` session. """ cmd = _build_cmd(weka_with_system_dir, cache_bust=target) result = cli.run_sync(cmd, timeout=defaults.timeout) @@ -329,10 +327,6 @@ def test_agentic_replay_cache_bust_marker_in_wire_payload( f"got {len(by_session)}: {list(by_session.keys())}" ) - is_first_turn_target = target in ( - CacheBustTarget.FIRST_TURN_PREFIX, - CacheBustTarget.FIRST_TURN_SUFFIX, - ) is_prefix_target = target in ( CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.FIRST_TURN_PREFIX, @@ -387,22 +381,19 @@ def test_agentic_replay_cache_bust_marker_in_wire_payload( ) session_rids[xcorr] = next(iter(rids_in_session)) - if is_first_turn_target: - # FIRST_TURN_* only fires when credit.turn_index == 0. With our - # 6-trace fixture + concurrency=3 + duration=8s, recycled sessions - # always start at turn 0, so at least one session must be marked. - assert len(session_rids) >= 1, ( - f"target={target}: no session received a FIRST_TURN marker. " - f"Recycled sessions begin at turn_index=0 and must inject. " - f"Total sessions={len(by_session)}, " - f"unmarked={len(sessions_without_marker)}" - ) - else: - # SYSTEM_* applies on every turn -> every session must be marked. - assert not sessions_without_marker, ( - f"target={target}: SYSTEM_* must mark every session; " - f"unmarked={sessions_without_marker}" - ) + # Every profiling session must carry a marker, for ALL targets. FIRST_TURN_* + # marks the effective prefix's opening user turn on every credit — including + # seeded mid-trajectory resumes at k_i > 0 — so an unmarked session here is a + # regression of the seeded-resume fix. SYSTEM_* applies every turn. (This + # fixture is linear, no FORK children — FORK inheritance is covered in the + # DAG cache-bust test and the worker unit tests.) + assert not sessions_without_marker, ( + f"target={target}: every session must be marked, but these were not: " + f"{sessions_without_marker}. Total sessions={len(by_session)}." + ) + assert len(session_rids) >= 1, ( + f"target={target}: no marked sessions at all (fixture/run too small?)." + ) # Cross-session distinctness: among marked sessions we want >= 2 distinct # rids whenever there are >= 2 marked sessions (which is the common case). diff --git a/tests/component_integration/test_agentic_replay_cache_bust_collision_free.py b/tests/component_integration/test_agentic_replay_cache_bust_collision_free.py index c0931ac9a3..0bf965e1d6 100644 --- a/tests/component_integration/test_agentic_replay_cache_bust_collision_free.py +++ b/tests/component_integration/test_agentic_replay_cache_bust_collision_free.py @@ -44,11 +44,14 @@ def weka_collision_fixture(tmp_path: Path) -> Path: def _build_cmd(weka_dir: Path, *, duration: int) -> str: - """Build an aiperf command tuned to drive >=50 distinct sessions. - - 4 traces x concurrency=3 plus a 6s benchmark window forces continuous - recycle of the small pool; 100+ recycles per trace are typical, which - means hundreds of x_correlation_ids each of which mints a fresh marker. + """Build an aiperf command that drives many distinct recycled sessions. + + 4 traces x concurrency=3 over a multi-second benchmark window forces + continuous recycle of the small pool, so each completed session mints a + fresh marker. The exact session count is wall-clock-dependent (it scales + with machine speed); the assertion floor below is set well under what even + a loaded machine produces so the zero-collision contract -- not throughput + -- is what the test gates on. """ return f""" aiperf profile @@ -90,10 +93,13 @@ def test_no_marker_collisions_across_large_recycle_run( Asserts (within PROFILING): 1. Every session has exactly one rid (intra-session marker continuity). 2. ``len(set(rids)) == len(rids)`` across all sessions (zero collisions). - 3. >=50 distinct rids observed (smoke check that the run was big enough - to be a meaningful uniqueness test). + 3. >=20 distinct rids observed -- a non-vacuity floor, set well below the + session count a loaded machine produces so it does not flake on + throughput. The zero-collision check (2) is the real regression bar: + the pre-fix 33% collision rate is caught with ~99.9% probability even + at 20 sessions, so this floor does not weaken detection. """ - cmd = _build_cmd(weka_collision_fixture, duration=6) + cmd = _build_cmd(weka_collision_fixture, duration=10) result = cli.run_sync(cmd, timeout=defaults.timeout) assert result.exit_code == 0, ( @@ -129,9 +135,10 @@ def test_no_marker_collisions_across_large_recycle_run( ) session_rids.append(next(iter(rids_in_session))) - assert len(session_rids) >= 50, ( - f"Need >=50 sessions for a meaningful uniqueness test; " - f"got {len(session_rids)}. Increase duration or shrink fixture." + assert len(session_rids) >= 20, ( + f"Need >=20 sessions for a non-vacuous uniqueness test; " + f"got {len(session_rids)}. Increase --benchmark-duration or shrink the " + f"fixture if a slower machine is under-producing sessions." ) # The hard contract: zero duplicates across the entire run. diff --git a/tests/component_integration/test_agentic_replay_cache_bust_reset.py b/tests/component_integration/test_agentic_replay_cache_bust_reset.py new file mode 100644 index 0000000000..e72e408882 --- /dev/null +++ b/tests/component_integration/test_agentic_replay_cache_bust_reset.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end: cache-bust markers survive ``reset_context`` turns. + +Validates the reset-semantics fix on the actual wire. A weka trace fixture is +crafted so that turn 1 is a non-monotonic LCP cut -> the loader emits it with +``reset_context=True``. Under ``reset_context`` the endpoint's ``build_messages`` +discards the accumulated turn-0 prefix and restarts the wire payload from the +reset turn, so the reset turn becomes a brand-new prefix. + +The test asserts two things together: + +1. ACTUAL MARKERS: every profiling session carries exactly one ``[rid:HEX]`` + cache-bust marker in its wire payload (markers are really injected, not + silently dropped), and no wire message ever carries more than one (no + stacking). + +2. RESET SEMANTICS: the reset-turn requests (``turn_index == 1`` for these + two-turn traces) carry the marker on their first user message -- the new + post-reset prefix. Before the reset fix the marker landed on the discarded + turn 0, leaving the reset turn's wire prefix unmarked; this asserts the + marker follows the effective prefix across the cut. + +A deterministic loader-level check (no benchmark) first proves the fixture +genuinely produces a ``reset_context`` turn, so the end-to-end assertions are +exercising the reset path rather than passing vacuously. +""" + +from __future__ import annotations + +import json +import re +from collections import defaultdict +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from aiperf.dataset.loader.weka_trace import WekaTraceLoader +from tests.component_integration.conftest import ComponentIntegrationTestDefaults +from tests.harness.utils import AIPerfCLI + +pytestmark = pytest.mark.component_integration + +defaults = ComponentIntegrationTestDefaults +_MODEL = "claude-opus-4-5-20251101" +_BLOCK_SIZE = 16 +_RID_RE = re.compile(r"\[rid:[0-9a-f]{12}\]") + + +def _req(t: float, hash_ids: list[int], in_tokens: int) -> dict: + return { + "t": t, + "type": "n", + "model": _MODEL, + "in": in_tokens, + "out": 8, + "hash_ids": hash_ids, + "input_types": ["text"], + "output_types": ["text"], + "stop": "end_turn", + "api_time": 0.05, + "think_time": 0.0, + } + + +def _write_reset_fixture(target_dir: Path, *, num_traces: int = 6) -> Path: + """Write weka traces whose turn 1 is a non-monotonic LCP cut (reset_context). + + Turn 0 establishes a single 5-block user segment ``[1,2,3,4,5]``. Turn 1's + ``hash_ids`` share only ``[1,2]`` (LCP=2), landing inside the previously + emitted segment -> the reconstructor records a disturbance on an emitted + segment and flags ``reset_context=True`` (see ConversationReconstructor. + turn_delta case 3). Two turns => agentic_replay picks k_i=0 and resumes + profiling at turn 1, so the reset turn is the profiling turn. + """ + target_dir.mkdir(parents=True, exist_ok=True) + for n in range(1, num_traces + 1): + # Distinct post-cut blocks per trace keep hash_ids varied across traces. + cut_a, cut_b = 100 + 2 * n, 101 + 2 * n + trace = { + "id": f"reset_trace_{n:02d}", + "models": [_MODEL], + "block_size": _BLOCK_SIZE, + "hash_id_scope": "local", + "tool_tokens": 0, + "system_tokens": 0, + "requests": [ + _req(0.0, [1, 2, 3, 4, 5], 5 * _BLOCK_SIZE), + _req(1.0, [1, 2, cut_a, cut_b], 4 * _BLOCK_SIZE), + ], + } + (target_dir / f"reset_trace_{n:02d}.json").write_text(json.dumps(trace)) + return target_dir + + +def _loader_user_config() -> MagicMock: + uc = MagicMock() + uc.input.random_seed = 0 + uc.input.fixed_schedule_start_offset = None + uc.input.fixed_schedule_end_offset = None + uc.input.ignore_trace_delays = False + uc.input.use_think_time_only = False + uc.input.synthesis.max_isl = None + uc.input.synthesis.max_osl = None + uc.input.max_context_length = None + uc.input.synthesis.should_synthesize.return_value = False + uc.input.prompt.input_tokens.block_size = None + uc.tokenizer.trust_remote_code = False + uc.tokenizer.revision = None + uc.tokenizer.name = "test-tok" + uc.endpoint.model_names = [_MODEL] + uc.loadgen.inter_turn_delay_cap_seconds = None + return uc + + +def _fixture_produces_reset(trace_file: Path) -> bool: + """Deterministically load one trace file and report whether any turn carries + ``reset_context=True`` (proves the fixture exercises the reset path).""" + loader = WekaTraceLoader( + filename=str(trace_file), user_config=_loader_user_config() + ) + loader.synthesize_prompts_from_hash_ids = lambda rs: {r.key: "p" for r in rs} + pg = MagicMock() + pg._corpus_size = 10000 + pg._tokenized_corpus = list(range(10000)) + pg.tokenizer.decode = lambda tokens: f"decoded-{len(tokens)}" + loader.prompt_generator = pg + loader._tokenizer_name = "t" + loader._trust_remote_code = False + loader._tokenizer_revision = None + loader._block_size = _BLOCK_SIZE + + convs = loader.convert_to_conversations(loader.load_dataset()) + return any(t.reset_context for c in convs for t in c.turns) + + +def _build_cmd(weka_dir: Path) -> str: + return f""" + aiperf profile + --model claude-haiku-4-5-20251001 + --model {_MODEL} + --endpoint-type chat + --streaming + --custom-dataset-type weka_trace + --input-file {weka_dir} + --no-fixed-schedule + --benchmark-duration 8 + --concurrency 3 + --random-seed 42 + --tokenizer {defaults.tokenizer} + --extra-inputs ignore_eos:true + --workers-max {defaults.workers_max} + --ui {defaults.ui} + --scenario inferencex-agentx-mvp + --unsafe-override + --cache-bust first_turn_prefix + --export-level raw + """ + + +def _payload_dict(record) -> dict: + if record.payload is not None: + return record.payload + if record.payload_bytes is not None: + return json.loads(record.payload_bytes) + return {} + + +def _first_user_content(payload: dict) -> str | None: + for msg in payload.get("messages", []): + if isinstance(msg, dict) and msg.get("role") == "user": + content = msg.get("content") + return content if isinstance(content, str) else None + return None + + +def _max_rids_in_any_message(payload: dict) -> int: + counts = [0] + for msg in payload.get("messages", []): + if isinstance(msg, dict) and isinstance(msg.get("content"), str): + counts.append(len(_RID_RE.findall(msg["content"]))) + return max(counts) + + +@pytest.fixture +def weka_reset_dir(tmp_path: Path) -> Path: + return _write_reset_fixture(tmp_path / "weka_reset", num_traces=6) + + +def test_reset_fixture_actually_produces_reset_context(weka_reset_dir: Path): + """Guard: the crafted fixture genuinely yields a reset_context turn, so the + end-to-end test below is exercising the reset path (not passing vacuously).""" + trace_file = next(weka_reset_dir.glob("*.json")) + assert _fixture_produces_reset(trace_file), ( + "fixture no longer triggers reset_context — the LCP-cut shape or the " + "reconstructor's reset rule changed; the end-to-end test would be vacuous" + ) + + +def test_cache_bust_marker_survives_reset_context( + cli: AIPerfCLI, + weka_reset_dir: Path, +): + cmd = _build_cmd(weka_reset_dir) + result = cli.run_sync(cmd, timeout=defaults.timeout) + + assert result.exit_code == 0, ( + f"CLI run failed: stderr=\n{result.stderr}" + f"\nlog tail=\n{(result.log or '')[-2000:]}" + ) + assert result.raw_records is not None and len(result.raw_records) > 0, ( + "raw records JSONL must be present and non-empty" + ) + + by_session: dict[str, list] = defaultdict(list) + reset_turn_records: list = [] + for rec in result.raw_records: + if rec.metadata.benchmark_phase != "profiling": + continue + # No wire message may carry a stacked marker, ever. + assert _max_rids_in_any_message(_payload_dict(rec)) <= 1, ( + f"stacked rid markers in a single message: " + f"conv={rec.metadata.conversation_id} ti={rec.metadata.turn_index}" + ) + xcorr = rec.metadata.x_correlation_id + if xcorr is not None: + by_session[xcorr].append(rec) + # turn 1 is the reset turn for these two-turn traces. + if rec.metadata.turn_index == 1: + reset_turn_records.append(rec) + + # ACTUAL MARKERS: every profiling session carries exactly one rid. + session_rids: list[str] = [] + for xcorr, records in by_session.items(): + rids: set[str] = set() + for rec in records: + fu = _first_user_content(_payload_dict(rec)) or "" + m = _RID_RE.search(fu) + if m: + rids.add(m.group(0)) + assert len(rids) == 1, ( + f"session={xcorr}: expected exactly one rid across " + f"{len(records)} turns; got {rids}" + ) + session_rids.append(next(iter(rids))) + + assert len(session_rids) >= 2, ( + f"need >=2 marked sessions for a meaningful run; got {len(session_rids)}" + ) + # Cross-session distinctness (collision-free per play/lane/trace). + assert len(set(session_rids)) == len(session_rids), ( + f"marker collision across sessions: {session_rids}" + ) + + # RESET SEMANTICS: the reset-turn (turn 1) requests carry the marker on the + # post-reset prefix. Without the reset fix the marker would land on the + # discarded turn 0 and these would be unmarked. + assert reset_turn_records, "expected at least one reset-turn (turn_index==1) record" + for rec in reset_turn_records: + fu = _first_user_content(_payload_dict(rec)) + assert fu is not None and _RID_RE.search(fu), ( + f"reset turn wire prefix is unmarked (marker lost across reset): " + f"conv={rec.metadata.conversation_id} first_user={fu!r}" + ) diff --git a/tests/integration/test_agentic_replay_cache_bust.py b/tests/integration/test_agentic_replay_cache_bust.py new file mode 100644 index 0000000000..8f32187433 --- /dev/null +++ b/tests/integration/test_agentic_replay_cache_bust.py @@ -0,0 +1,309 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""REAL integration tests for cache-bust markers through the full aiperf +subprocess (real ZMQ + real workers + real HTTP against the mock server). + +The ``component_integration`` suite covers the same behaviors in-process via +FakeCommunication; these run the actual subprocess so they also exercise the +serialization and multiprocess session/worker paths. Covered here: + +- Per-target marker injection on the wire (FIRST_TURN_* on the first user turn, + SYSTEM_* on the system message), one marker per session, distinct across + sessions (collision-free minting), never stacked. +- NONE target: no markers anywhere. +- SPAWN subagent fan-out: subagent children are independently busted with their + OWN marker, distinct from the parent root's (the reachable production fan-out + + cache-bust path; contrast with FORK, which inherits and is unreachable here). +""" + +from __future__ import annotations + +import json +import re +from collections import defaultdict +from pathlib import Path + +import pytest + +from aiperf.common.enums import CacheBustTarget +from tests.harness.utils import AIPerfCLI, AIPerfMockServer + +_OPUS = "claude-opus-4-5-20251101" +_HAIKU = "claude-haiku-4-5-20251001" +_TOKENIZER = "openai/gpt-oss-120b" # pre-cached + offline in integration conftest +_BLOCK_SIZE = 16 +_RID_RE = re.compile(r"\[rid:[0-9a-f]{12}\]") + + +# --- fixtures --------------------------------------------------------------- + + +def _req(t: float, hash_ids: list[int], in_tokens: int, *, model: str = _OPUS) -> dict: + return { + "t": t, + "type": "n", + "model": model, + "in": in_tokens, + "out": 8, + "hash_ids": hash_ids, + "input_types": ["text"], + "output_types": ["text"], + "stop": "end_turn", + "api_time": 0.05, + "think_time": 0.0, + } + + +def _write_linear_fixture(target_dir: Path, *, num_traces: int = 6) -> Path: + """Linear multi-turn weka traces with a system prefix (tool+system tokens) + so SYSTEM_* targets have a system-role message to inject into.""" + target_dir.mkdir(parents=True, exist_ok=True) + for n in range(1, num_traces + 1): + requests = [] + for k in range(max(2, n)): # >=2 turns so agentic_replay can split + user_blocks = k + 1 + in_tokens = (1 + user_blocks) * _BLOCK_SIZE + 4 + hash_ids = list(range(1, 1 + 1 + user_blocks)) # 1 sys + N user + requests.append(_req(k * 1.0, hash_ids, in_tokens)) + trace = { + "id": f"lin_trace_{n:02d}", + "models": [_OPUS], + "block_size": _BLOCK_SIZE, + "hash_id_scope": "local", + "tool_tokens": 8, + "system_tokens": 8, + "requests": requests, + } + (target_dir / f"lin_trace_{n:02d}.json").write_text(json.dumps(trace)) + return target_dir + + +def _write_subagent_fixture(target_dir: Path, *, num_traces: int = 5) -> Path: + """Weka traces each carrying a ``type:subagent`` entry -> a SPAWN child.""" + target_dir.mkdir(parents=True, exist_ok=True) + for i in range(1, num_traces + 1): + base = i * 10 + trace = { + "id": f"sa_trace_{i:02d}", + "models": [_OPUS], + "block_size": 64, + "hash_id_scope": "local", + "requests": [ + _req(0.0, [base + 1, base + 2, base + 3], 200), + { + "t": 2.0, + "type": "subagent", + "agent_id": f"agent_{i:03d}", + "subagent_type": "Explore", + "duration_ms": 3000, + "total_tokens": 500, + "tool_use_count": 2, + "status": "completed", + "requests": [ + _req(0.0, [base + 100, base + 101], 100, model=_HAIKU), + ], + "models": [_HAIKU], + "tool_tokens": 20, + "system_tokens": 10, + }, + _req(6.0, [base + 1, base + 2, base + 3, base + 4, base + 5], 400), + ], + } + # subagent inner request stop must be a tool-using stop on the parent turn 0 + trace["requests"][0]["stop"] = "tool_use" + trace["requests"][2]["input_types"] = ["tool_result"] + (target_dir / f"sa_trace_{i:02d}.json").write_text(json.dumps(trace)) + return target_dir + + +# --- helpers ---------------------------------------------------------------- + + +def _payload_dict(record) -> dict: + if record.payload is not None: + return record.payload + if record.payload_bytes is not None: + return json.loads(record.payload_bytes) + return {} + + +def _content_of(payload: dict, role: str) -> str | None: + for msg in payload.get("messages", []): + if isinstance(msg, dict) and msg.get("role") == role: + c = msg.get("content") + return c if isinstance(c, str) else None + return None + + +def _carrier_text(payload: dict, target: CacheBustTarget) -> str | None: + if target in (CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.SYSTEM_SUFFIX): + return _content_of(payload, "system") + return _content_of(payload, "user") + + +def _max_rids_in_any_message(payload: dict) -> int: + counts = [0] + for msg in payload.get("messages", []): + if isinstance(msg, dict) and isinstance(msg.get("content"), str): + counts.append(len(_RID_RE.findall(msg["content"]))) + return max(counts) + + +def _build_cmd(weka_dir: Path, url: str, cache_bust: str) -> str: + return f""" + aiperf profile \ + --model {_HAIKU} \ + --model {_OPUS} \ + --url {url} \ + --endpoint-type chat \ + --streaming \ + --custom-dataset-type weka_trace \ + --input-file {weka_dir} \ + --no-fixed-schedule \ + --benchmark-duration 8 \ + --concurrency 3 \ + --random-seed 42 \ + --tokenizer {_TOKENIZER} \ + --extra-inputs ignore_eos:true \ + --workers-max 2 \ + --scenario inferencex-agentx-mvp \ + --unsafe-override \ + --cache-bust {cache_bust} \ + --export-level raw \ + --ui simple + """ + + +def _profiling_records(result) -> list: + return [ + r + for r in (result.raw_records or []) + if r.metadata.benchmark_phase == "profiling" + ] + + +# --- tests ------------------------------------------------------------------ + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "target", + [ + CacheBustTarget.FIRST_TURN_PREFIX, + CacheBustTarget.FIRST_TURN_SUFFIX, + CacheBustTarget.SYSTEM_PREFIX, + CacheBustTarget.SYSTEM_SUFFIX, + ], +) +async def test_marker_in_wire_payload_real_subprocess( + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, + target: CacheBustTarget, +): + """Every profiling session carries exactly one marker in the target's + carrier, distinct across sessions (collision-free), never stacked.""" + weka_dir = _write_linear_fixture(tmp_path / f"lin_{target.value}") + result = await cli.run( + _build_cmd(weka_dir, aiperf_mock_server.url, target.value), timeout=300.0 + ) + records = _profiling_records(result) + assert records, f"no profiling records\n{(result.log or '')[-1500:]}" + + by_session: dict[str, set[str]] = defaultdict(set) + for rec in records: + payload = _payload_dict(rec) + assert _max_rids_in_any_message(payload) <= 1, ( + f"target={target}: stacked markers in conv={rec.metadata.conversation_id}" + ) + carrier = _carrier_text(payload, target) or "" + m = _RID_RE.search(carrier) + xcorr = rec.metadata.x_correlation_id + if xcorr is not None and m: + by_session[xcorr].add(m.group(0)) + + assert by_session, f"target={target}: no markers found in any session carrier" + for xcorr, rids in by_session.items(): + assert len(rids) == 1, f"target={target} session={xcorr}: multiple rids {rids}" + all_rids = [next(iter(v)) for v in by_session.values()] + assert len(set(all_rids)) == len(all_rids), ( + f"target={target}: marker collision across sessions: {all_rids}" + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_none_target_has_no_markers_real_subprocess( + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, +): + """With --cache-bust none, no rid marker appears anywhere on the wire.""" + weka_dir = _write_linear_fixture(tmp_path / "lin_none") + result = await cli.run( + _build_cmd(weka_dir, aiperf_mock_server.url, "none"), timeout=300.0 + ) + records = _profiling_records(result) + assert records, f"no profiling records\n{(result.log or '')[-1500:]}" + for rec in records: + for role, content in ( + (m.get("role"), m.get("content")) + for m in _payload_dict(rec).get("messages", []) + if isinstance(m, dict) + ): + if isinstance(content, str): + assert not _RID_RE.search(content), ( + f"NONE target leaked a marker into {role} content: {content[:80]!r}" + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_spawn_subagent_children_busted_with_own_marker_real_subprocess( + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, +): + """SPAWN subagent children get their OWN cache-bust marker (busted), distinct + from the parent root's marker -- the reachable production fan-out path. This + exercises the non-FORK branch of the worker's cache-bust guard end-to-end. + """ + weka_dir = _write_subagent_fixture(tmp_path / "sa", num_traces=5) + result = await cli.run( + _build_cmd(weka_dir, aiperf_mock_server.url, "first_turn_prefix"), + timeout=300.0, + ) + records = _profiling_records(result) + assert records, f"no profiling records\n{(result.log or '')[-1500:]}" + + # rids seen on root (depth==0) sessions, grouped by base conversation id. + root_rids_by_base: dict[str, set[str]] = defaultdict(set) + child_records: list = [] + for rec in records: + payload = _payload_dict(rec) + assert _max_rids_in_any_message(payload) <= 1, ( + f"stacked markers in conv={rec.metadata.conversation_id}" + ) + m = _RID_RE.search(_content_of(payload, "user") or "") + conv = rec.metadata.conversation_id or "" + if rec.metadata.agent_depth and "::sa:" in conv: + child_records.append((rec, m.group(0) if m else None)) + elif m: + root_rids_by_base[conv].add(m.group(0)) + + assert child_records, ( + "expected SPAWN subagent child records (agent_depth>0, '::sa:' in conv id) " + "-- the subagent fan-out did not materialize" + ) + for rec, child_rid in child_records: + conv = rec.metadata.conversation_id or "" + base = conv.split("::sa:")[0] + assert child_rid is not None, ( + f"SPAWN child {conv} carries no cache-bust marker (not busted)" + ) + # The child's marker must be its OWN, not inherited from the parent root. + assert child_rid not in root_rids_by_base.get(base, set()), ( + f"SPAWN child {conv} reused a parent-root marker {child_rid} " + f"(should be independently busted): root rids={root_rids_by_base.get(base)}" + ) diff --git a/tests/integration/test_agentic_replay_cache_bust_reset.py b/tests/integration/test_agentic_replay_cache_bust_reset.py new file mode 100644 index 0000000000..54cc79dec6 --- /dev/null +++ b/tests/integration/test_agentic_replay_cache_bust_reset.py @@ -0,0 +1,222 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""REAL integration test: cache-bust markers survive reset_context turns. + +Unlike the ``component_integration`` counterpart (single process, +FakeCommunication), this spins up the full ``aiperf`` subprocess against the +shared mock server over real ZMQ + real workers + real HTTP, so it also covers +the serialization / multiprocess session paths. + +A weka trace fixture is crafted so turn 1 is a non-monotonic LCP cut, which the +loader emits with ``reset_context=True``. Two-turn traces => agentic_replay +resumes profiling at turn 1, making the reset turn the profiling turn. The run +asserts that cache-bust markers are actually injected into the wire payload and +land on the post-reset prefix (the reset-semantics fix), with no stacking. +""" + +from __future__ import annotations + +import json +import re +from collections import defaultdict +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from aiperf.dataset.loader.weka_trace import WekaTraceLoader +from tests.harness.utils import AIPerfCLI, AIPerfMockServer + +_MODEL = "claude-opus-4-5-20251101" +_TOKENIZER = "openai/gpt-oss-120b" # pre-cached + offline in integration conftest +_BLOCK_SIZE = 16 +_RID_RE = re.compile(r"\[rid:[0-9a-f]{12}\]") + + +def _req(t: float, hash_ids: list[int], in_tokens: int) -> dict: + return { + "t": t, + "type": "n", + "model": _MODEL, + "in": in_tokens, + "out": 8, + "hash_ids": hash_ids, + "input_types": ["text"], + "output_types": ["text"], + "stop": "end_turn", + "api_time": 0.05, + "think_time": 0.0, + } + + +def _write_reset_fixture(target_dir: Path, *, num_traces: int = 6) -> Path: + """Weka traces whose turn 1 is a non-monotonic LCP cut (reset_context=True). + + Turn 0 is a single 5-block user segment ``[1,2,3,4,5]``; turn 1 shares only + ``[1,2]`` (LCP=2), landing inside the emitted segment -> reset. See the + component-integration counterpart for the full rationale. + """ + target_dir.mkdir(parents=True, exist_ok=True) + for n in range(1, num_traces + 1): + cut_a, cut_b = 100 + 2 * n, 101 + 2 * n + trace = { + "id": f"reset_trace_{n:02d}", + "models": [_MODEL], + "block_size": _BLOCK_SIZE, + "hash_id_scope": "local", + "tool_tokens": 0, + "system_tokens": 0, + "requests": [ + _req(0.0, [1, 2, 3, 4, 5], 5 * _BLOCK_SIZE), + _req(1.0, [1, 2, cut_a, cut_b], 4 * _BLOCK_SIZE), + ], + } + (target_dir / f"reset_trace_{n:02d}.json").write_text(json.dumps(trace)) + return target_dir + + +def _fixture_produces_reset(trace_file: Path) -> bool: + """Deterministically confirm the fixture yields a reset_context turn.""" + uc = MagicMock() + uc.input.random_seed = 0 + uc.input.fixed_schedule_start_offset = None + uc.input.fixed_schedule_end_offset = None + uc.input.ignore_trace_delays = False + uc.input.use_think_time_only = False + uc.input.synthesis.max_isl = None + uc.input.synthesis.max_osl = None + uc.input.max_context_length = None + uc.input.synthesis.should_synthesize.return_value = False + uc.input.prompt.input_tokens.block_size = None + uc.tokenizer.trust_remote_code = False + uc.tokenizer.revision = None + uc.tokenizer.name = "test-tok" + uc.endpoint.model_names = [_MODEL] + uc.loadgen.inter_turn_delay_cap_seconds = None + + loader = WekaTraceLoader(filename=str(trace_file), user_config=uc) + loader.synthesize_prompts_from_hash_ids = lambda rs: {r.key: "p" for r in rs} + pg = MagicMock() + pg._corpus_size = 10000 + pg._tokenized_corpus = list(range(10000)) + pg.tokenizer.decode = lambda tokens: f"decoded-{len(tokens)}" + loader.prompt_generator = pg + loader._tokenizer_name = "t" + loader._trust_remote_code = False + loader._tokenizer_revision = None + loader._block_size = _BLOCK_SIZE + convs = loader.convert_to_conversations(loader.load_dataset()) + return any(t.reset_context for c in convs for t in c.turns) + + +def _payload_dict(record) -> dict: + if record.payload is not None: + return record.payload + if record.payload_bytes is not None: + return json.loads(record.payload_bytes) + return {} + + +def _first_user_content(payload: dict) -> str | None: + for msg in payload.get("messages", []): + if isinstance(msg, dict) and msg.get("role") == "user": + content = msg.get("content") + return content if isinstance(content, str) else None + return None + + +def _max_rids_in_any_message(payload: dict) -> int: + counts = [0] + for msg in payload.get("messages", []): + if isinstance(msg, dict) and isinstance(msg.get("content"), str): + counts.append(len(_RID_RE.findall(msg["content"]))) + return max(counts) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_agentic_replay_cache_bust_marker_survives_reset_real_subprocess( + cli: AIPerfCLI, + aiperf_mock_server: AIPerfMockServer, + tmp_path: Path, +): + weka_dir = _write_reset_fixture(tmp_path / "weka_reset", num_traces=6) + + # Guard: the fixture genuinely produces a reset_context turn, so the + # end-to-end assertions are not vacuous. + trace_file = next(weka_dir.glob("*.json")) + assert _fixture_produces_reset(trace_file), ( + "fixture no longer triggers reset_context; the reset path is not exercised" + ) + + result = await cli.run( + f""" + aiperf profile \ + --model claude-haiku-4-5-20251001 \ + --model {_MODEL} \ + --url {aiperf_mock_server.url} \ + --endpoint-type chat \ + --streaming \ + --custom-dataset-type weka_trace \ + --input-file {weka_dir} \ + --no-fixed-schedule \ + --benchmark-duration 8 \ + --concurrency 3 \ + --random-seed 42 \ + --tokenizer {_TOKENIZER} \ + --extra-inputs ignore_eos:true \ + --workers-max 2 \ + --scenario inferencex-agentx-mvp \ + --unsafe-override \ + --cache-bust first_turn_prefix \ + --export-level raw \ + --ui simple + """, + timeout=300.0, + ) + + assert result.raw_records is not None and len(result.raw_records) > 0, ( + "profile_export_raw.jsonl must exist and be non-empty\n" + f"{(result.log or '')[-2000:]}" + ) + + by_session: dict[str, list] = defaultdict(list) + reset_turn_records: list = [] + for rec in result.raw_records: + if rec.metadata.benchmark_phase != "profiling": + continue + assert _max_rids_in_any_message(_payload_dict(rec)) <= 1, ( + f"stacked rid markers: conv={rec.metadata.conversation_id} " + f"ti={rec.metadata.turn_index}" + ) + xcorr = rec.metadata.x_correlation_id + if xcorr is not None: + by_session[xcorr].append(rec) + if rec.metadata.turn_index == 1: + reset_turn_records.append(rec) + + # Actual markers: one distinct rid per profiling session. + session_rids: list[str] = [] + for xcorr, records in by_session.items(): + rids: set[str] = set() + for rec in records: + m = _RID_RE.search(_first_user_content(_payload_dict(rec)) or "") + if m: + rids.add(m.group(0)) + assert len(rids) == 1, f"session={xcorr}: expected exactly one rid; got {rids}" + session_rids.append(next(iter(rids))) + + assert len(session_rids) >= 2, f"need >=2 marked sessions; got {len(session_rids)}" + assert len(set(session_rids)) == len(session_rids), ( + f"marker collision across sessions: {session_rids}" + ) + + # Reset semantics: reset-turn (turn 1) requests carry the marker on the + # post-reset prefix. Without the reset fix these would be unmarked. + assert reset_turn_records, "expected at least one reset-turn (turn_index==1) record" + for rec in reset_turn_records: + fu = _first_user_content(_payload_dict(rec)) + assert fu is not None and _RID_RE.search(fu), ( + f"reset turn wire prefix unmarked (marker lost across reset): " + f"conv={rec.metadata.conversation_id} first_user={fu!r}" + ) diff --git a/tests/unit/workers/test_worker_cache_bust_injection.py b/tests/unit/workers/test_worker_cache_bust_injection.py index 93d9fd7ea4..fb533d2f52 100644 --- a/tests/unit/workers/test_worker_cache_bust_injection.py +++ b/tests/unit/workers/test_worker_cache_bust_injection.py @@ -1,10 +1,12 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from aiperf.common.enums import CacheBustTarget, CreditPhase +import pytest + +from aiperf.common.enums import CacheBustTarget, ConversationBranchMode, CreditPhase from aiperf.common.models.dataset_models import Conversation, Text, Turn from aiperf.credit.structs import Credit -from aiperf.workers.session_manager import UserSession +from aiperf.workers.session_manager import UserSession, UserSessionManager from aiperf.workers.worker import ( _apply_cache_bust, _apply_cache_bust_to_system_message, @@ -99,12 +101,64 @@ def test_inject_marker_into_raw_messages_suffix(): assert raw[0]["content"] == "you are helpful" + _SUFFIX_MARKER +def test_inject_marker_into_raw_messages_prefix_idempotent(): + """In DELTAS mode turn_list[0] is a single shared object re-visited every + credit; re-injecting the same marker must NOT stack it.""" + raw = [{"role": "system", "content": "you are helpful"}] + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + assert raw[0]["content"] == _PREFIX_MARKER + "you are helpful" + + +def test_inject_marker_into_raw_messages_suffix_idempotent(): + raw = [{"role": "system", "content": "you are helpful"}] + _inject_marker_into_raw_messages(raw, _SUFFIX_MARKER, is_prefix=False) + _inject_marker_into_raw_messages(raw, _SUFFIX_MARKER, is_prefix=False) + assert raw[0]["content"] == "you are helpful" + _SUFFIX_MARKER + + +def test_inject_marker_into_raw_messages_multimodal_idempotent(): + raw = [{"role": "system", "content": [{"type": "text", "text": "hi"}]}] + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + assert raw[0]["content"] == [ + {"type": "text", "text": _PREFIX_MARKER.strip()}, + {"type": "text", "text": "hi"}, + ] + + def test_inject_marker_no_system_role_is_noop(): raw = [{"role": "user", "content": "hi"}] _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) assert raw[0]["content"] == "hi" +def test_inject_first_user_turn_idempotent_prefix(): + """Injection is unconditional per credit (seeded resume marks turn 0 every + credit); the helper must not stack the marker on repeated calls.""" + raw = [{"role": "user", "content": "hi"}] + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + assert raw[0]["content"] == _PREFIX_MARKER + "hi" + + +def test_inject_first_user_turn_idempotent_multimodal(): + raw = [{"role": "user", "content": [{"type": "text", "text": "hi"}]}] + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + assert raw[0]["content"] == [ + {"type": "text", "text": _PREFIX_MARKER.strip()}, + {"type": "text", "text": "hi"}, + ] + + +def test_inject_first_user_text_idempotent(): + turn = Turn(raw_messages=None, texts=[Text(contents=["hello"])]) + _inject_marker_into_first_user_text(turn, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_first_user_text(turn, _PREFIX_MARKER, is_prefix=True) + assert turn.texts[0].contents[0] == _PREFIX_MARKER + "hello" + + def test_inject_marker_empty_raw_is_noop(): raw: list[dict] = [] _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) @@ -181,7 +235,9 @@ def test_system_prefix_uses_existing_raw_system_role_when_no_conversation_system assert msgs[1]["content"] == "hi" -def test_system_prefix_fallback_no_op_on_turn_index_gt_zero(): +def test_system_prefix_fallback_marks_first_user_on_turn_index_gt_zero(): + """SYSTEM_PREFIX with no system anywhere falls back to the first user turn, + and now injects every credit (seeded-resume fix) rather than only turn 0.""" raw = [{"role": "user", "content": "hi"}] session = _make_session(raw, num_turns=2) credit = _make_credit( @@ -194,7 +250,7 @@ def test_system_prefix_fallback_no_op_on_turn_index_gt_zero(): out = _apply_cache_bust(session, credit, system_message=None) assert out is None - assert session.turn_list[-1].raw_messages[0]["content"] == "hi" + assert session.turn_list[-1].raw_messages[0]["content"] == _PREFIX_MARKER + "hi" def test_first_turn_prefix_unaffected_by_system_message_presence(): @@ -630,3 +686,888 @@ def test_apply_system_prefix_no_system_under_deltas_falls_back_to_turn_0_user(): assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "hi" assert session.turn_list[1].raw_messages[1]["content"] == "follow up" + + +# ============================================================================= +# reset_context re-injection (FIRST_TURN_*) +# ============================================================================= +# A turn carrying ``reset_context=True`` makes the endpoint's build_messages +# discard every accumulated prior turn and start the wire payload fresh from +# that turn's raw_messages. The turn-0 marker is no longer in the effective +# prefix, so the marker must be RE-APPLIED to the reset turn — otherwise every +# recycled play of the trace replays a byte-identical post-reset prefix and the +# server's prefix cache warms across plays (the exact thing cache-bust prevents). + + +def _make_delta_session_with_resets( + turns_raw: list[list[dict] | None], reset_flags: list[bool] +) -> UserSession: + """Like ``_make_delta_session`` but sets ``reset_context`` per turn.""" + turns = [ + Turn(raw_messages=raw, reset_context=reset) + for raw, reset in zip(turns_raw, reset_flags, strict=True) + ] + conversation = Conversation(session_id="conv_test", turns=list(turns)) + return UserSession( + x_correlation_id="xcorr_test", + num_turns=len(turns), + conversation=conversation, + turn_list=list(turns), + ) + + +def test_first_turn_prefix_reapplied_on_reset_context_turn(): + """FIRST_TURN_PREFIX at turn_index > 0 must inject into the reset turn (the + new wire prefix), not be skipped as it is for ordinary later turns.""" + turn_0 = [ + {"role": "system", "content": "rules"}, + {"role": "user", "content": "hi"}, + ] + # Reset turn: build_messages discards turn 0 and starts here. + turn_1_reset = [ + {"role": "system", "content": "fresh rules"}, + {"role": "user", "content": "new prefix"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset], reset_flags=[False, True] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + # The reset turn's first user message carries the marker. + assert ( + session.turn_list[1].raw_messages[1]["content"] == _PREFIX_MARKER + "new prefix" + ) + # Turn 0 (discarded from the wire) is left untouched. + assert session.turn_list[0].raw_messages[1]["content"] == "hi" + + +def test_first_turn_suffix_reapplied_on_reset_context_turn(): + turn_0 = [{"role": "user", "content": "hi"}] + turn_1_reset = [{"role": "user", "content": "new prefix"}] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset], reset_flags=[False, True] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_SUFFIX, + marker=_SUFFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert ( + session.turn_list[1].raw_messages[0]["content"] == "new prefix" + _SUFFIX_MARKER + ) + + +def test_first_turn_prefix_marks_prefix_turn_on_ordinary_later_turn(): + """A non-reset turn at index > 0 re-marks the shared turn-0 prefix + (idempotent) and leaves the later turn's own user content untouched.""" + turn_0 = [{"role": "user", "content": "hi"}] + turn_1 = [ + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "follow up"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1], reset_flags=[False, False] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "hi" + assert session.turn_list[1].raw_messages[1]["content"] == "follow up" + + +def test_first_turn_prefix_reset_on_turn_zero_uses_turn_zero_path_once(): + """A reset flag on turn 0 still resolves through the turn-0 path and injects + exactly once (no double application).""" + turn_0_reset = [{"role": "user", "content": "hi"}] + session = _make_delta_session_with_resets([turn_0_reset], reset_flags=[True]) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=0, + num_turns=1, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "hi" + + +# ============================================================================= +# Seeded mid-trajectory resume (FIRST_TURN_* / SYSTEM_* sub-path 3) +# ============================================================================= +# Agentic replay can resume a trajectory at turn k_i > 0. The worker's +# advance_turn back-fills turns 0..k_i into turn_list, so turn 0 (the real wire +# prefix) is present even though credit.turn_index > 0. The turn-0 gate missed +# it; injection now runs every credit and is idempotent. + + +def test_first_turn_prefix_marks_seeded_turn_zero_on_resume(): + """FIRST_TURN_PREFIX at turn_index > 0 with no reset must mark the seeded + turn 0 (the conversation's opening prefix), not be skipped.""" + turn_0 = [{"role": "user", "content": "u0"}] + turn_1 = [ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ] + turn_2 = [ + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1, turn_2], reset_flags=[False, False, False] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=2, + num_turns=3, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + # Later seeded turns' user messages are untouched. + assert session.turn_list[2].raw_messages[1]["content"] == "u2" + + +def test_first_turn_prefix_resume_then_next_turn_no_stacking(): + """The seeded turn 0 is shared across the session's turns; processing the + resume credit then the next turn must mark it exactly once.""" + turn_0 = [{"role": "user", "content": "u0"}] + turn_1 = [ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1], reset_flags=[False, False] + ) + + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + + # Next turn on the same session re-runs injection; idempotent -> no stack. + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + + +def test_system_prefix_subpath3_marks_seeded_turn_zero_on_resume(): + """SYSTEM_PREFIX with no system anywhere falls back to first-user; under a + seeded resume it must still mark the seeded turn 0.""" + turn_0 = [{"role": "user", "content": "u0"}] + turn_1 = [ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1], reset_flags=[False, False] + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + + +# ============================================================================= +# FORK children: inherit the parent's prefix, never bust +# ============================================================================= +# A FORK child seeds turn_list = list(parent.turn_list) (SHARED Turn objects, +# same worker). It shares the parent's KV cache by design, so cache-bust must be +# a complete no-op: the child inherits the parent's already-injected marker via +# the shared object and must NOT re-bust it (which would diverge the prefix from +# the parent -> cache miss -> and corrupt the parent's shared Turn). + + +def _make_fork_child_session( + turns: list[Turn], *, num_turns: int | None = None +) -> UserSession: + conversation = Conversation(session_id="child", turns=list(turns)) + return UserSession( + x_correlation_id="child_xcorr", + num_turns=num_turns if num_turns is not None else len(turns), + conversation=conversation, + turn_list=list(turns), + parent_correlation_id="parent_xcorr", + branch_mode=ConversationBranchMode.FORK, + ) + + +def test_fork_child_first_turn_is_noop_inherits_parent_marker(): + # The shared turn 0 already carries the PARENT's marker (injected by the + # parent's session). The child must leave it untouched. + parent_marked_t0 = Turn( + raw_messages=[{"role": "user", "content": "[rid:PARENT00000]\n\nu0"}], + reset_context=False, + ) + child_turn = Turn( + raw_messages=[ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ], + reset_context=False, + ) + session = _make_fork_child_session([parent_marked_t0, child_turn]) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker="[rid:CHILD000000]\n\n", + turn_index=1, + num_turns=2, + ) + + out = _apply_cache_bust(session, credit, system_message=None) + + assert out is None + # Parent's marker preserved verbatim; the child's marker is NOT added. + assert session.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + + +def test_fork_child_system_target_is_noop(): + parent_marked_sys = Turn( + raw_messages=[ + {"role": "system", "content": "[rid:PARENT00000]\n\nS0"}, + {"role": "user", "content": "u0"}, + ], + reset_context=False, + ) + child_turn = Turn( + raw_messages=[ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ], + reset_context=False, + ) + session = _make_fork_child_session([parent_marked_sys, child_turn]) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker="[rid:CHILD000000]\n\n", + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nS0" + + +def test_spawn_child_is_busted_normally(): + """SPAWN children start fresh (no shared parent turns), so they are busted + like a root session.""" + t0 = Turn(raw_messages=[{"role": "user", "content": "u0"}], reset_context=False) + conversation = Conversation(session_id="spawn", turns=[t0]) + session = UserSession( + x_correlation_id="spawn_xcorr", + num_turns=1, + conversation=conversation, + turn_list=[t0], + parent_correlation_id="parent_xcorr", + branch_mode=ConversationBranchMode.SPAWN, + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=0, + num_turns=1, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + + +# ============================================================================= +# Buried reset_context (reset turn is NOT the current turn) +# ============================================================================= +# build_messages restarts the wire array at every reset_context turn, so the +# effective prefix is the LAST reset turn in turn_list. That turn may sit +# mid-history (seeded on a resume, never dispatched as the current turn), so +# inspecting only turn_list[-1] would mark the discarded turn 0 instead. + + +def test_first_turn_prefix_marks_buried_reset_turn_not_discarded_turn_zero(): + turn_0 = [{"role": "user", "content": "u0"}] + turn_1_reset = [ + {"role": "system", "content": "S1"}, + {"role": "user", "content": "u1"}, + ] + turn_2 = [ + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset, turn_2], reset_flags=[False, True, False] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=2, + num_turns=3, + ) + + _apply_cache_bust(session, credit, system_message=None) + + # Marker lands on the buried reset turn (the real wire prefix), not turn 0. + assert session.turn_list[1].raw_messages[1]["content"] == _PREFIX_MARKER + "u1" + assert session.turn_list[0].raw_messages[0]["content"] == "u0" + assert session.turn_list[2].raw_messages[1]["content"] == "u2" + + +def test_system_prefix_marks_buried_reset_turn_system(): + turn_0 = [ + {"role": "system", "content": "S0"}, + {"role": "user", "content": "u0"}, + ] + turn_1_reset = [ + {"role": "system", "content": "S1"}, + {"role": "user", "content": "u1"}, + ] + turn_2 = [ + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset, turn_2], reset_flags=[False, True, False] + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=2, + num_turns=3, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[1].raw_messages[0]["content"] == _PREFIX_MARKER + "S1" + # Discarded turn 0 system left untouched. + assert session.turn_list[0].raw_messages[0]["content"] == "S0" + + +def test_first_turn_prefix_marks_only_last_of_multiple_resets(): + """With two resets, only the last (the effective prefix) is marked.""" + turn_0_reset = [{"role": "user", "content": "u0"}] + turn_1_reset = [{"role": "user", "content": "u1"}] + turn_2 = [ + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + session = _make_delta_session_with_resets( + [turn_0_reset, turn_1_reset, turn_2], reset_flags=[True, True, False] + ) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=2, + num_turns=3, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[1].raw_messages[0]["content"] == _PREFIX_MARKER + "u1" + assert session.turn_list[0].raw_messages[0]["content"] == "u0" + + +# ============================================================================= +# reset_context re-injection (SYSTEM_*) +# ============================================================================= +# Same defect as FIRST_TURN_*, but for the SYSTEM_* sub-paths that mutate a +# turn's raw_messages. Sub-path 1 (Conversation-level system_message) is safe +# because the marker rides on RequestInfo.system_message and is re-emitted every +# turn independent of build_messages' reset. Sub-paths 2 (raw role=="system" in +# a turn) and 3 (no system -> first-user fallback) marked the discarded turn 0 +# instead of the reset turn's fresh prefix; these tests pin the fix. + + +def test_system_prefix_reapplied_on_reset_turn_with_own_system(): + """Sub-path 2 under reset: the reset turn's own system message (the new wire + prefix), not the discarded turn 0 system, must carry the marker.""" + turn_0 = [ + {"role": "system", "content": "S0"}, + {"role": "user", "content": "u0"}, + ] + turn_1_reset = [ + {"role": "system", "content": "S1"}, + {"role": "user", "content": "u1"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset], reset_flags=[False, True] + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + out = _apply_cache_bust(session, credit, system_message=None) + + assert out is None + assert session.turn_list[1].raw_messages[0]["content"] == _PREFIX_MARKER + "S1" + # Discarded turn 0 system left untouched on this credit. + assert session.turn_list[0].raw_messages[0]["content"] == "S0" + + +def test_system_suffix_reapplied_on_reset_turn_with_own_system(): + turn_0 = [{"role": "system", "content": "S0"}] + turn_1_reset = [ + {"role": "system", "content": "S1"}, + {"role": "user", "content": "u1"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset], reset_flags=[False, True] + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_SUFFIX, + marker=_SUFFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[1].raw_messages[0]["content"] == "S1" + _SUFFIX_MARKER + assert session.turn_list[0].raw_messages[0]["content"] == "S0" + + +def test_system_prefix_reset_no_system_falls_back_to_reset_turn_user(): + """Sub-path 3 under reset: no system anywhere, so the marker falls back to + the reset turn's first user message (its new prefix), not turn 0's.""" + turn_0 = [{"role": "user", "content": "u0"}] + turn_1_reset = [{"role": "user", "content": "u1"}] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset], reset_flags=[False, True] + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[1].raw_messages[0]["content"] == _PREFIX_MARKER + "u1" + assert session.turn_list[0].raw_messages[0]["content"] == "u0" + + +def test_system_prefix_subpath2_no_stacking_across_delta_turns(): + """Sub-path 2 dispatch: under DELTAS the shared turn_list[0] system is + re-visited on every credit. The marker must be injected once and not stack + turn-over-turn (the original 'inject every turn' design stacked here).""" + turn_0 = [ + {"role": "system", "content": "S0"}, + {"role": "user", "content": "u0"}, + ] + turn_1 = [ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1], reset_flags=[False, False] + ) + + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=0, + num_turns=2, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "S0" + + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ), + system_message=None, + ) + # Still exactly one marker, not stacked. + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "S0" + + +def test_system_prefix_conversation_message_safe_under_reset(): + """Sub-path 1 regression: a Conversation-level system_message is re-marked + every turn and rides on RequestInfo, so reset never strips it. The returned + string carries the marker and the raw turns stay untouched.""" + turns = [ + Turn( + raw_messages=[{"role": "user", "content": "u0"}], + reset_context=False, + ), + Turn( + raw_messages=[{"role": "user", "content": "u1"}], + reset_context=True, + ), + ] + conversation = Conversation( + session_id="conv_test", turns=list(turns), system_message="CONV" + ) + session = UserSession( + x_correlation_id="xcorr_test", + num_turns=2, + conversation=conversation, + turn_list=list(turns), + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + out = _apply_cache_bust(session, credit, system_message="CONV") + + assert out == _PREFIX_MARKER + "CONV" + assert session.turn_list[1].raw_messages[0]["content"] == "u1" + + +# ============================================================================= +# Extensive matrix: session-type x target x prefix-scenario interactions +# ============================================================================= +# These lock the full interaction surface that bit us repeatedly: FORK (shared, +# inherit-don't-bust) vs SPAWN/root (own prefix, bust) crossed with all four +# targets, multi-turn persistence, idempotency, and reset/seeded-resume combos. + + +_ALL_TARGETS = [ + CacheBustTarget.FIRST_TURN_PREFIX, + CacheBustTarget.FIRST_TURN_SUFFIX, + CacheBustTarget.SYSTEM_PREFIX, + CacheBustTarget.SYSTEM_SUFFIX, +] + + +def _marker_for(target: CacheBustTarget) -> str: + """Prefix targets need a trailing-newline marker; suffix targets a leading one + (mirrors build_cache_bust_marker's placement).""" + return ( + _SUFFIX_MARKER + if target in (CacheBustTarget.FIRST_TURN_SUFFIX, CacheBustTarget.SYSTEM_SUFFIX) + else _PREFIX_MARKER + ) + + +# ---- FORK is a no-op for every target ------------------------------------- + + +@pytest.mark.parametrize("target", _ALL_TARGETS) +def test_fork_child_is_noop_for_all_targets(target: CacheBustTarget): + """A FORK child must never re-bust its inherited prefix, regardless of target. + The shared turn carries only the parent's marker; the child adds nothing.""" + parent_marked = Turn( + raw_messages=[ + {"role": "system", "content": "[rid:PARENT00000]\n\nS0"}, + {"role": "user", "content": "[rid:PARENT00000]\n\nu0"}, + ], + reset_context=False, + ) + child_turn = Turn( + raw_messages=[ + {"role": "assistant", "content": "a0"}, + {"role": "user", "content": "u1"}, + ], + reset_context=False, + ) + session = _make_fork_child_session([parent_marked, child_turn]) + before_sys = session.turn_list[0].raw_messages[0]["content"] + before_user = session.turn_list[0].raw_messages[1]["content"] + credit = _make_credit( + target=target, marker="[rid:CHILD000000]\n\n", turn_index=1, num_turns=2 + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[0]["content"] == before_sys + assert session.turn_list[0].raw_messages[1]["content"] == before_user + + +def test_fork_child_noop_even_when_conversation_system_message_present(): + """SYSTEM sub-path 1 (Conversation-level system_message) is also skipped for + a FORK child: the child returns it unchanged rather than applying its own + marker (which would diverge from the parent's system prefix).""" + parent_marked = Turn( + raw_messages=[{"role": "user", "content": "u0"}], reset_context=False + ) + conversation = Conversation( + session_id="child", turns=[parent_marked], system_message="CONV" + ) + session = UserSession( + x_correlation_id="child_xcorr", + num_turns=1, + conversation=conversation, + turn_list=[parent_marked], + parent_correlation_id="parent_xcorr", + branch_mode=ConversationBranchMode.FORK, + ) + credit = _make_credit( + target=CacheBustTarget.SYSTEM_PREFIX, marker=_PREFIX_MARKER, num_turns=1 + ) + + out = _apply_cache_bust(session, credit, system_message="CONV") + + # Returned unchanged (NOT marker + "CONV"). + assert out == "CONV" + + +def test_fork_child_multi_turn_prefix_stays_single_marked(): + """Processing several FORK-child credits never stacks onto the shared turn 0.""" + shared_t0 = Turn( + raw_messages=[{"role": "user", "content": "[rid:PARENT00000]\n\nu0"}], + reset_context=False, + ) + later = Turn( + raw_messages=[ + {"role": "assistant", "content": "a"}, + {"role": "user", "content": "u_later"}, + ], + reset_context=False, + ) + session = _make_fork_child_session([shared_t0, later], num_turns=2) + for ti in (1, 1, 1): + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker="[rid:CHILD000000]\n\n", + turn_index=ti, + num_turns=2, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + + +def test_fork_child_with_own_reset_is_still_noop(): + """A FORK child carrying its OWN reset_context turn is still a no-op: FORK + never busts. (Documents current behavior — a child-introduced reset prefix + is not independently busted; revisit if that workload appears.)""" + shared_t0 = Turn( + raw_messages=[{"role": "user", "content": "[rid:PARENT00000]\n\nu0"}], + reset_context=False, + ) + child_reset = Turn( + raw_messages=[{"role": "user", "content": "child fresh prefix"}], + reset_context=True, + ) + session = _make_fork_child_session([shared_t0, child_reset], num_turns=2) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker="[rid:CHILD000000]\n\n", + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[1].raw_messages[0]["content"] == "child fresh prefix" + assert session.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + + +# ---- Realistic FORK lifecycle through UserSessionManager seeding ----------- + + +def test_fork_lifecycle_child_inherits_parents_marked_turn_object(): + """End-to-end at the session layer: a parent marks turn 0 in place, a FORK + child is seeded from the parent via create_and_store (sharing the SAME Turn + object), and the child's cache-bust is a no-op — so the child sends the + parent's exact marked prefix (byte-identical => prefix-cache hit).""" + mgr = UserSessionManager() + t0 = Turn(raw_messages=[{"role": "user", "content": "u0"}], reset_context=False) + parent_conv = Conversation(session_id="root", turns=[t0]) + parent = mgr.create_and_store("P", parent_conv, num_turns=1) + parent.advance_turn(0) + _apply_cache_bust( + parent, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker="[rid:PARENT00000]\n\n", + turn_index=0, + num_turns=1, + ), + system_message=None, + ) + assert parent.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + + # FORK child seeds turn_list from the parent (shallow copy -> shared Turn). + child = mgr.create_and_store( + "C", + parent_conv, + num_turns=1, + parent_correlation_id="P", + branch_mode=ConversationBranchMode.FORK, + ) + # The child shares the parent's marked turn-0 object by identity. + assert child.turn_list[0] is parent.turn_list[0] + + _apply_cache_bust( + child, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker="[rid:CHILD000000]\n\n", + turn_index=0, + num_turns=1, + ), + system_message=None, + ) + + # No-op: still exactly the parent's marker, no child marker, no stacking. + assert child.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + assert parent.turn_list[0].raw_messages[0]["content"] == "[rid:PARENT00000]\n\nu0" + + +# ---- SPAWN children and root sessions ARE busted (across targets) ---------- + + +@pytest.mark.parametrize("target", _ALL_TARGETS) +def test_spawn_child_is_busted_for_all_targets(target: CacheBustTarget): + marker = _marker_for(target) + raw = [ + {"role": "system", "content": "S0"}, + {"role": "user", "content": "u0"}, + ] + turn = Turn(raw_messages=[dict(m) for m in raw], reset_context=False) + conversation = Conversation(session_id="spawn", turns=[turn]) + session = UserSession( + x_correlation_id="spawn_xcorr", + num_turns=1, + conversation=conversation, + turn_list=[turn], + parent_correlation_id="parent_xcorr", + branch_mode=ConversationBranchMode.SPAWN, + ) + + _apply_cache_bust( + session, + _make_credit(target=target, marker=marker, turn_index=0, num_turns=1), + system_message=None, + ) + + msgs = session.turn_list[0].raw_messages + if target in (CacheBustTarget.SYSTEM_PREFIX, CacheBustTarget.SYSTEM_SUFFIX): + carrier = msgs[0]["content"] # system + else: + carrier = msgs[1]["content"] # first user + assert _PREFIX_MARKER.strip() in carrier + + +# ---- Idempotency stress: one session, many credits, single marker ---------- + + +def test_within_session_many_credits_single_marker_prefix(): + """A root session re-processed across many credits keeps exactly one marker + on the shared turn-0 object (idempotency holds turn-over-turn).""" + t0_raw = [{"role": "user", "content": "u0"}] + rest_raw = [ + [ + {"role": "assistant", "content": f"a{i}"}, + {"role": "user", "content": f"u{i + 1}"}, + ] + for i in range(4) + ] + session = _make_delta_session_with_resets( + [t0_raw, *rest_raw], reset_flags=[False] * 5 + ) + for ti in range(5): + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=ti, + num_turns=5, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == _PREFIX_MARKER + "u0" + + +def test_within_session_many_credits_single_marker_suffix(): + session = _make_delta_session_with_resets( + [[{"role": "user", "content": "u0"}]], reset_flags=[False] + ) + for _ in range(3): + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_SUFFIX, + marker=_SUFFIX_MARKER, + turn_index=0, + num_turns=1, + ), + system_message=None, + ) + assert session.turn_list[0].raw_messages[0]["content"] == "u0" + _SUFFIX_MARKER + + +# ---- Seeded-resume x reset combinations (suffix coverage) ------------------ + + +def test_seeded_resume_with_buried_reset_suffix(): + """Buried reset + suffix target on a seeded resume: marker suffixes the reset + turn's first user (the effective prefix), not the discarded turn 0.""" + turn_0 = [{"role": "user", "content": "u0"}] + turn_1_reset = [{"role": "user", "content": "u1"}] + turn_2 = [ + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + session = _make_delta_session_with_resets( + [turn_0, turn_1_reset, turn_2], reset_flags=[False, True, False] + ) + _apply_cache_bust( + session, + _make_credit( + target=CacheBustTarget.FIRST_TURN_SUFFIX, + marker=_SUFFIX_MARKER, + turn_index=2, + num_turns=3, + ), + system_message=None, + ) + assert session.turn_list[1].raw_messages[0]["content"] == "u1" + _SUFFIX_MARKER + assert session.turn_list[0].raw_messages[0]["content"] == "u0"