From acc87e3a88897b839d31fdca8561ee2088380d63 Mon Sep 17 00:00:00 2001 From: weireweire <20922698+weireweire@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:58:19 +0800 Subject: [PATCH 1/2] Add Dynamo session control support --- src/aiperf/common/config/config_defaults.py | 2 + src/aiperf/common/config/endpoint_config.py | 34 +++ .../common/models/model_endpoint_info.py | 15 ++ src/aiperf/workers/worker.py | 93 +++++++- .../common/config/test_endpoint_config.py | 10 + .../unit/common/models/test_endpoint_info.py | 25 ++- tests/unit/workers/test_worker.py | 211 +++++++++++++++++- 7 files changed, 386 insertions(+), 4 deletions(-) diff --git a/src/aiperf/common/config/config_defaults.py b/src/aiperf/common/config/config_defaults.py index d2dbb509ff..fd9e627341 100644 --- a/src/aiperf/common/config/config_defaults.py +++ b/src/aiperf/common/config/config_defaults.py @@ -48,6 +48,8 @@ class EndpointDefaults: CONNECTION_REUSE_STRATEGY = ConnectionReuseStrategy.POOLED DOWNLOAD_VIDEO_CONTENT = False REQUEST_CONTENT_TYPE = None + USE_DYNAMO_CONV_AWARE_ROUTING = False + DYNAMO_SESSION_TIMEOUT_SECONDS = 300 # Readiness probe defaults. Timeout 0 disables the probe (the default); # any positive value enables it. Interval is only consulted when the # probe is enabled but is validated positive so mis-configuration diff --git a/src/aiperf/common/config/endpoint_config.py b/src/aiperf/common/config/endpoint_config.py index 967e00d417..7861dc5c84 100644 --- a/src/aiperf/common/config/endpoint_config.py +++ b/src/aiperf/common/config/endpoint_config.py @@ -316,6 +316,40 @@ def url(self) -> str: ), ] = EndpointDefaults.USE_SERVER_TOKEN_COUNT + use_dynamo_conv_aware_routing: Annotated[ + bool, + Field( + description=( + "Emit Dynamo nvext.session_control in OpenAI-compatible request " + "bodies so Dynamo can bind all turns from the same replayed " + "conversation lineage to the same backend worker. This is only " + "intended for Dynamo frontends that implement session_control." + ), + ), + CLIParameter( + name=( + "--use-dynamo-conv-aware-routing", + "--use-dynamo-session-control", + ), + group=Groups.ENDPOINT, + ), + ] = EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING + + dynamo_session_timeout_seconds: Annotated[ + int, + Field( + description=( + "Dynamo nvext.session_control timeout in seconds when " + "--use-dynamo-conv-aware-routing is enabled." + ), + ge=1, + ), + CLIParameter( + name=("--dynamo-session-timeout-seconds",), + group=Groups.ENDPOINT, + ), + ] = EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS + connection_reuse_strategy: Annotated[ ConnectionReuseStrategy, Field( diff --git a/src/aiperf/common/models/model_endpoint_info.py b/src/aiperf/common/models/model_endpoint_info.py index 376fffd390..dbfe02e39c 100644 --- a/src/aiperf/common/models/model_endpoint_info.py +++ b/src/aiperf/common/models/model_endpoint_info.py @@ -114,6 +114,15 @@ class EndpointInfo(AIPerfBaseModel): default=EndpointDefaults.USE_SERVER_TOKEN_COUNT, description="Use server-reported token counts from API usage fields instead of client-side tokenization.", ) + use_dynamo_conv_aware_routing: bool = Field( + default=EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING, + description="Emit Dynamo nvext.session_control for conversation-aware routing.", + ) + dynamo_session_timeout_seconds: int = Field( + default=EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS, + ge=1, + description="Timeout in seconds for Dynamo nvext.session_control sessions.", + ) connection_reuse_strategy: ConnectionReuseStrategy = Field( default=EndpointDefaults.CONNECTION_REUSE_STRATEGY, description="Transport connection reuse strategy.", @@ -164,6 +173,12 @@ def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo": api_key=user_config.endpoint.api_key, use_legacy_max_tokens=user_config.endpoint.use_legacy_max_tokens, use_server_token_count=user_config.endpoint.use_server_token_count, + use_dynamo_conv_aware_routing=( + user_config.endpoint.use_dynamo_conv_aware_routing + ), + dynamo_session_timeout_seconds=( + user_config.endpoint.dynamo_session_timeout_seconds + ), connection_reuse_strategy=user_config.endpoint.connection_reuse_strategy, download_video_content=user_config.endpoint.download_video_content, request_content_type=user_config.endpoint.request_content_type, diff --git a/src/aiperf/workers/worker.py b/src/aiperf/workers/worker.py index 695fd0045b..e54cffbf7e 100644 --- a/src/aiperf/workers/worker.py +++ b/src/aiperf/workers/worker.py @@ -5,7 +5,7 @@ import asyncio import time import uuid -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import orjson @@ -460,6 +460,9 @@ def __init__( # high concurrency — the misconfiguration is the same for every credit. self._cache_bust_warning_shown: bool = False + # Worker-local bind cache for Dynamo conversation-aware routing. + self._dynamo_bound_session_ids: set[str] = set() + # Only used as a fallback when dataset client is not initialized # or was not available when the credit was dropped. Must be created here # so it can be attached to the worker lifecycle. @@ -938,6 +941,11 @@ def _create_request_info( credit = credit_context.credit if turns is None: turns = session.turn_list if session else [] + turns, payload_bytes = self._add_dynamo_session_control( + credit=credit, + turns=turns, + payload_bytes=payload_bytes, + ) return RequestInfo( model_endpoint=self.model_endpoint, credit_num=credit.id, @@ -967,6 +975,89 @@ def _create_request_info( else None, ) + def _add_dynamo_session_control( + self, + *, + credit: Credit, + turns: list[Turn], + payload_bytes: bytes | None, + ) -> tuple[list[Turn], bytes | None]: + """Inject Dynamo ``nvext.session_control`` when configured. + + The normal chat path merges ``Turn.extra_body`` into the wire payload. + Raw-payload paths bypass that formatter, so they are patched directly + when possible. ``payload_bytes`` is decoded only under the opt-in flag. + """ + endpoint = self.model_endpoint.endpoint + if not endpoint.use_dynamo_conv_aware_routing: + return turns, payload_bytes + + session_control = self._dynamo_session_control_for_credit(credit) + if payload_bytes is not None: + payload = orjson.loads(payload_bytes) + if not isinstance(payload, dict): + raise ValueError("Dynamo session_control requires object payload_bytes") + payload = self._merge_dynamo_session_control(payload, session_control) + if session_control.get("action") == "bind": + self._dynamo_bound_session_ids.add(session_control["session_id"]) + return turns, orjson.dumps(payload) + + if not turns: + return turns, payload_bytes + + last_turn = turns[-1] + updates: dict[str, Any] + if last_turn.raw_payload is not None: + raw_payload = self._merge_dynamo_session_control( + last_turn.raw_payload, + session_control, + ) + updates = {"raw_payload": raw_payload} + else: + updates = { + "extra_body": self._merge_dynamo_session_control( + last_turn.extra_body or {}, + session_control, + ) + } + + new_turns = list(turns) + new_turns[-1] = last_turn.model_copy(update=updates) + if session_control.get("action") == "bind": + self._dynamo_bound_session_ids.add(session_control["session_id"]) + return new_turns, payload_bytes + + def _dynamo_session_control_for_credit(self, credit: Credit) -> dict[str, Any]: + session_id = self._dynamo_session_id(credit) + session_control: dict[str, Any] = { + "session_id": session_id, + "timeout": self.model_endpoint.endpoint.dynamo_session_timeout_seconds, + } + if session_id not in self._dynamo_bound_session_ids: + session_control["action"] = "bind" + return session_control + + @staticmethod + def _dynamo_session_id(credit: Credit) -> str: + return credit.x_correlation_id + + @staticmethod + def _merge_dynamo_session_control( + payload: dict[str, Any], + session_control: dict[str, Any], + ) -> dict[str, Any]: + merged = dict(payload) + raw_nvext = merged.get("nvext") + nvext = dict(raw_nvext) if isinstance(raw_nvext, dict) else {} + raw_session_control = nvext.get("session_control") + merged_session_control = ( + dict(raw_session_control) if isinstance(raw_session_control, dict) else {} + ) + merged_session_control.update(session_control) + nvext["session_control"] = merged_session_control + merged["nvext"] = nvext + return merged + async def _retrieve_conversation( self, *, diff --git a/tests/unit/common/config/test_endpoint_config.py b/tests/unit/common/config/test_endpoint_config.py index 35c069c7f4..2fa94bae3b 100644 --- a/tests/unit/common/config/test_endpoint_config.py +++ b/tests/unit/common/config/test_endpoint_config.py @@ -26,6 +26,14 @@ def test_endpoint_config_defaults(): assert config.custom_endpoint == EndpointDefaults.CUSTOM_ENDPOINT assert config.streaming == EndpointDefaults.STREAMING assert config.url == EndpointDefaults.URL + assert ( + config.use_dynamo_conv_aware_routing + == EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING + ) + assert ( + config.dynamo_session_timeout_seconds + == EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS + ) def test_endpoint_config_custom_values(): @@ -49,6 +57,8 @@ def test_endpoint_config_custom_values(): "urls": ["http://custom-url"], "timeout_seconds": 10, "api_key": "custom_api_key", + "use_dynamo_conv_aware_routing": True, + "dynamo_session_timeout_seconds": 123, } config = EndpointConfig(**custom_values) for key, value in custom_values.items(): diff --git a/tests/unit/common/models/test_endpoint_info.py b/tests/unit/common/models/test_endpoint_info.py index 2d81e2a596..30a723f74d 100644 --- a/tests/unit/common/models/test_endpoint_info.py +++ b/tests/unit/common/models/test_endpoint_info.py @@ -4,8 +4,8 @@ import pytest -from aiperf.common.config import EndpointDefaults -from aiperf.common.models.model_endpoint_info import EndpointInfo +from aiperf.common.config import EndpointConfig, EndpointDefaults, UserConfig +from aiperf.common.models.model_endpoint_info import EndpointInfo, ModelEndpointInfo class TestEndpointInfoMultiURL: @@ -16,6 +16,14 @@ def test_single_url_default(self): info = EndpointInfo() assert info.base_urls == [EndpointDefaults.URL] assert info.base_url == EndpointDefaults.URL + assert ( + info.use_dynamo_conv_aware_routing + == EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING + ) + assert ( + info.dynamo_session_timeout_seconds + == EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS + ) def test_single_url_custom(self): """Custom single URL should work.""" @@ -35,6 +43,19 @@ def test_base_urls_must_have_at_least_one(self): with pytest.raises(ValueError): EndpointInfo(base_urls=[]) + def test_dynamo_session_control_from_user_config(self): + """Dynamo session-control fields should flow into runtime endpoint info.""" + user_config = UserConfig( + endpoint=EndpointConfig( + model_names=["test-model"], + use_dynamo_conv_aware_routing=True, + dynamo_session_timeout_seconds=123, + ) + ) + info = ModelEndpointInfo.from_user_config(user_config).endpoint + assert info.use_dynamo_conv_aware_routing is True + assert info.dynamo_session_timeout_seconds == 123 + class TestEndpointInfoGetUrl: """Tests for EndpointInfo.get_url() method.""" diff --git a/tests/unit/workers/test_worker.py b/tests/unit/workers/test_worker.py index e931000208..b559680cd1 100644 --- a/tests/unit/workers/test_worker.py +++ b/tests/unit/workers/test_worker.py @@ -7,16 +7,19 @@ from aiperf.common.config.service_config import ServiceConfig from aiperf.common.config.user_config import UserConfig -from aiperf.common.enums import CreditPhase +from aiperf.common.enums import CacheBustTarget, CreditPhase from aiperf.common.models import ( Conversation, ParsedResponse, ReasoningResponseData, RequestRecord, SSEMessage, + Text, TextResponseData, + Turn, ) from aiperf.credit.structs import Credit, CreditContext +from aiperf.workers.session_manager import UserSession from aiperf.workers.worker import Worker from tests.harness.fake_communication import FakeCommunication as FakeCommunication from tests.harness.fake_service_manager import FakeServiceManager as FakeServiceManager @@ -24,6 +27,54 @@ from tests.harness.fake_transport import FakeTransport as FakeTransport +def _make_credit( + *, + conversation_id: str = "conv-a", + x_correlation_id: str = "xcorr-a", + turn_index: int = 0, + phase: CreditPhase = CreditPhase.PROFILING, + cache_bust_marker: str | None = None, + parent_correlation_id: str | None = None, +) -> Credit: + return Credit( + id=1, + phase=phase, + conversation_id=conversation_id, + x_correlation_id=x_correlation_id, + turn_index=turn_index, + num_turns=max(turn_index + 1, 1), + issued_at_ns=0, + parent_correlation_id=parent_correlation_id, + cache_bust_marker=cache_bust_marker, + cache_bust_target=( + CacheBustTarget.FIRST_TURN_PREFIX + if cache_bust_marker + else CacheBustTarget.NONE + ), + ) + + +def _make_credit_context(credit: Credit) -> CreditContext: + return CreditContext(credit=credit, drop_perf_ns=0) + + +def _make_session( + *, + conversation_id: str = "conv-a", + x_correlation_id: str = "xcorr-a", + turn_index: int = 0, + turn: Turn | None = None, +) -> UserSession: + current_turn = turn or Turn(role="user", texts=[Text(contents=["hello"])]) + return UserSession( + x_correlation_id=x_correlation_id, + num_turns=max(turn_index + 1, 1), + conversation=Conversation(session_id=conversation_id, turns=[current_turn]), + turn_list=[current_turn], + turn_index=turn_index, + ) + + @pytest.fixture async def mock_worker( user_config: UserConfig, @@ -130,6 +181,164 @@ async def test_process_response_mixed_reasoning_and_text_combines_content( turn = await mock_worker._process_response(RequestRecord()) assert turn.texts[0].contents == ["HelloWorld"] + async def test_create_request_info_dynamo_session_control_default_disabled( + self, mock_worker + ): + """Default config should not add Dynamo request-body extensions.""" + turn = Turn(role="user", texts=[Text(contents=["hello"])]) + credit = _make_credit( + turn_index=3, + cache_bust_marker="[rid:stable]\n\n", + ) + request_info = mock_worker._create_request_info( + x_request_id="req-1", + credit_context=_make_credit_context(credit), + session=_make_session(turn_index=3, turn=turn), + ) + + assert request_info.turns[-1].extra_body is None + assert turn.extra_body is None + + async def test_create_request_info_dynamo_session_control_uses_x_correlation_id( + self, mock_worker + ): + """Dynamo session_control uses the request's X-Correlation-ID.""" + mock_worker.model_endpoint.endpoint.use_dynamo_conv_aware_routing = True + marker = "[rid:stable]\n\n" + + warmup_credit = _make_credit( + x_correlation_id="session-xcorr", + turn_index=3, + phase=CreditPhase.WARMUP, + cache_bust_marker=marker, + ) + warmup_request = mock_worker._create_request_info( + x_request_id="req-warmup", + credit_context=_make_credit_context(warmup_credit), + session=_make_session(x_correlation_id="session-xcorr", turn_index=3), + ) + warmup_session_control = warmup_request.turns[-1].extra_body["nvext"][ + "session_control" + ] + + profile_credit = _make_credit( + x_correlation_id="session-xcorr", + turn_index=4, + phase=CreditPhase.PROFILING, + cache_bust_marker=marker, + ) + profile_request = mock_worker._create_request_info( + x_request_id="req-profile", + credit_context=_make_credit_context(profile_credit), + session=_make_session(x_correlation_id="session-xcorr", turn_index=4), + ) + profile_session_control = profile_request.turns[-1].extra_body["nvext"][ + "session_control" + ] + + assert warmup_session_control == { + "session_id": "session-xcorr", + "timeout": 300, + "action": "bind", + } + assert "session-xcorr" in mock_worker._dynamo_bound_session_ids + assert profile_session_control == { + "session_id": "session-xcorr", + "timeout": 300, + } + + async def test_create_request_info_dynamo_session_control_ignores_parent_key( + self, mock_worker + ): + """Dynamo session identity follows the child credit, not sticky routing key.""" + mock_worker.model_endpoint.endpoint.use_dynamo_conv_aware_routing = True + credit = _make_credit( + x_correlation_id="child-xcorr", + parent_correlation_id="parent-xcorr", + cache_bust_marker="[rid:child]\n\n", + ) + + request = mock_worker._create_request_info( + x_request_id="req-child", + credit_context=_make_credit_context(credit), + session=_make_session(x_correlation_id="child-xcorr"), + ) + + assert request.turns[-1].extra_body["nvext"]["session_control"] == { + "session_id": "child-xcorr", + "timeout": 300, + "action": "bind", + } + + async def test_create_request_info_dynamo_session_control_preserves_extra_body( + self, mock_worker + ): + """Dynamo injection should merge with existing per-turn extra_body.""" + mock_worker.model_endpoint.endpoint.use_dynamo_conv_aware_routing = True + mock_worker.model_endpoint.endpoint.dynamo_session_timeout_seconds = 123 + turn = Turn( + role="user", + texts=[Text(contents=["hello"])], + extra_body={ + "temperature": 0, + "nvext": { + "trace": "keep", + "session_control": {"existing": "keep"}, + }, + }, + ) + credit = _make_credit(x_correlation_id="xcorr-raw") + + request_info = mock_worker._create_request_info( + x_request_id="req-merge", + credit_context=_make_credit_context(credit), + session=_make_session(x_correlation_id="xcorr-raw", turn=turn), + ) + + extra_body = request_info.turns[-1].extra_body + assert extra_body == { + "temperature": 0, + "nvext": { + "trace": "keep", + "session_control": { + "existing": "keep", + "session_id": "xcorr-raw", + "timeout": 123, + "action": "bind", + }, + }, + } + assert turn.extra_body["nvext"]["session_control"] == {"existing": "keep"} + + async def test_create_request_info_dynamo_session_control_patches_raw_payload( + self, mock_worker + ): + """Raw payload replay bypasses extra_body, so patch its payload directly.""" + mock_worker.model_endpoint.endpoint.use_dynamo_conv_aware_routing = True + turn = Turn( + role="user", + raw_payload={"messages": [{"role": "user", "content": "hi"}]}, + ) + credit = _make_credit(x_correlation_id="xcorr-raw") + + request_info = mock_worker._create_request_info( + x_request_id="req-raw", + credit_context=_make_credit_context(credit), + session=_make_session(x_correlation_id="xcorr-raw", turn=turn), + ) + + assert request_info.turns[-1].raw_payload == { + "messages": [{"role": "user", "content": "hi"}], + "nvext": { + "session_control": { + "session_id": "xcorr-raw", + "timeout": 300, + "action": "bind", + } + }, + } + assert turn.raw_payload == {"messages": [{"role": "user", "content": "hi"}]} + # --- FirstToken Callback Test Helpers --- From 8c9067d0bd8d13161606b2ab59307f896b82ae99 Mon Sep 17 00:00:00 2001 From: weireweire <20922698+weireweire@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:53:01 +0800 Subject: [PATCH 2/2] Fix first-turn cache-bust marker injection --- src/aiperf/workers/worker.py | 32 +++++++++--- .../test_worker_cache_bust_injection.py | 50 +++++++++++++++++++ 2 files changed, 74 insertions(+), 8 deletions(-) diff --git a/src/aiperf/workers/worker.py b/src/aiperf/workers/worker.py index e54cffbf7e..4f762628ad 100644 --- a/src/aiperf/workers/worker.py +++ b/src/aiperf/workers/worker.py @@ -101,6 +101,22 @@ def _apply_cache_bust_to_system_message( return system_message +def _content_contains_marker(content: Any, marker: str) -> bool: + """Return whether a message/text payload already carries this marker.""" + marker_text = marker.strip() + markers = [value for value in (marker, marker_text) if value] + if isinstance(content, str): + return any(value in content for value in markers) + if isinstance(content, list): + for part in content: + if not isinstance(part, dict): + continue + text = part.get("text") + if isinstance(text, str) and any(value in text for value in markers): + return True + return False + + def _inject_marker_into_raw_messages( raw_messages: list[dict], marker: str, *, is_prefix: bool ) -> None: @@ -117,6 +133,8 @@ def _inject_marker_into_raw_messages( if not isinstance(first, dict) or first.get("role") != "system": return content = first.get("content", "") + if _content_contains_marker(content, marker): + return if isinstance(content, str): raw_messages[0] = { **first, @@ -149,6 +167,8 @@ def _inject_marker_into_first_user_turn( for idx, msg in enumerate(raw_messages): if isinstance(msg, dict) and msg.get("role") == "user": content = msg.get("content", "") + if _content_contains_marker(content, marker): + return if isinstance(content, str): raw_messages[idx] = { **msg, @@ -231,6 +251,8 @@ def _inject_marker_into_first_user_text( first.contents = [marker.strip()] return existing = first.contents[0] + if _content_contains_marker(existing, marker): + return first.contents[0] = (marker + existing) if is_prefix else (existing + marker) @@ -277,12 +299,7 @@ def _apply_cache_bust( ``system_message`` nor a leading ``role=="system"`` entry in any turn's ``raw_messages``), the marker is routed to the first user turn with the same prefix/suffix orientation — i.e. SYSTEM_PREFIX falls back to a - first-user-turn prefix, SYSTEM_SUFFIX falls back to a first-user-turn - suffix. Without a system prompt the first user message is the prefix of - the entire wire payload, so this produces the same physical token-0 - divergence without fabricating a system role. The fallback is gated on - ``credit.turn_index == 0`` (matches FIRST_TURN_* semantics: marker only - affects the first turn's KV cache; later turns inherit). + first-user-turn prefix, SYSTEM_SUFFIX falls back to a first-user-turn suffix. """ marker = credit.cache_bust_marker target = credit.cache_bust_target @@ -317,8 +334,7 @@ def _apply_cache_bust( _inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix) return system_message - if credit.turn_index == 0: - _inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix) + _inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix) return system_message diff --git a/tests/unit/workers/test_worker_cache_bust_injection.py b/tests/unit/workers/test_worker_cache_bust_injection.py index 93d9fd7ea4..c5f002bdfc 100644 --- a/tests/unit/workers/test_worker_cache_bust_injection.py +++ b/tests/unit/workers/test_worker_cache_bust_injection.py @@ -90,6 +90,18 @@ def test_inject_marker_into_raw_messages_prefix(): assert raw[0]["content"] == _PREFIX_MARKER + "you are helpful" +def test_inject_marker_into_raw_messages_prefix_is_idempotent(): + raw = [ + {"role": "system", "content": "you are helpful"}, + {"role": "user", "content": "hi"}, + ] + + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_raw_messages(raw, _PREFIX_MARKER, is_prefix=True) + + assert raw[0]["content"] == _PREFIX_MARKER + "you are helpful" + + def test_inject_marker_into_raw_messages_suffix(): raw = [ {"role": "system", "content": "you are helpful"}, @@ -117,6 +129,15 @@ def test_inject_first_user_turn_prefix_with_system_present(): assert raw[1]["content"] == _PREFIX_MARKER + "hi" +def test_inject_first_user_turn_prefix_is_idempotent(): + raw = [{"role": "system", "content": "sys"}, {"role": "user", "content": "hi"}] + + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + _inject_marker_into_first_user_turn(raw, _PREFIX_MARKER, is_prefix=True) + + assert raw[1]["content"] == _PREFIX_MARKER + "hi" + + def test_inject_first_user_turn_suffix_user_only(): raw = [{"role": "user", "content": "hi"}] _inject_marker_into_first_user_turn(raw, _SUFFIX_MARKER, is_prefix=False) @@ -611,6 +632,35 @@ def test_apply_first_turn_prefix_under_deltas_injects_into_turn_0_user_role(): assert session.turn_list[1].raw_messages[1]["content"] == "follow up" +def test_apply_first_turn_prefix_under_deltas_mid_turn_marks_seeded_turn_0_once(): + """Agentic replay can start at turn_index>0 after seeding turns 0..k-1. + + FIRST_TURN_PREFIX must still attach to the seeded first user turn. Repeated + calls on the same mutable session should not duplicate the marker. + """ + turn_0 = [ + {"role": "system", "content": "rules"}, + {"role": "user", "content": "hi"}, + ] + turn_1_delta = [ + {"role": "assistant", "content": "hello"}, + {"role": "user", "content": "follow up"}, + ] + session = _make_delta_session([turn_0, turn_1_delta]) + credit = _make_credit( + target=CacheBustTarget.FIRST_TURN_PREFIX, + marker=_PREFIX_MARKER, + turn_index=1, + num_turns=2, + ) + + _apply_cache_bust(session, credit, system_message=None) + _apply_cache_bust(session, credit, system_message=None) + + assert session.turn_list[0].raw_messages[1]["content"] == _PREFIX_MARKER + "hi" + assert session.turn_list[1].raw_messages[1]["content"] == "follow up" + + def test_apply_system_prefix_no_system_under_deltas_falls_back_to_turn_0_user(): """No system anywhere + delta-mode turn_list -> fallback marks turn 0 user only.""" turn_0 = [{"role": "user", "content": "hi"}]