Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/aiperf/common/config/config_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class EndpointDefaults:
CONNECTION_REUSE_STRATEGY = ConnectionReuseStrategy.POOLED
DOWNLOAD_VIDEO_CONTENT = False
REQUEST_CONTENT_TYPE = None
USE_DYNAMO_CONV_AWARE_ROUTING = False
DYNAMO_SESSION_TIMEOUT_SECONDS = 300
# Readiness probe defaults. Timeout 0 disables the probe (the default);
# any positive value enables it. Interval is only consulted when the
# probe is enabled but is validated positive so mis-configuration
Expand Down
34 changes: 34 additions & 0 deletions src/aiperf/common/config/endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,40 @@ def url(self) -> str:
),
] = EndpointDefaults.USE_SERVER_TOKEN_COUNT

# Opt-in switch for emitting Dynamo ``nvext.session_control`` in request
# bodies. The default comes from
# ``EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING``. Two CLI names are
# registered for this option, both in the ENDPOINT group.
use_dynamo_conv_aware_routing: Annotated[
bool,
Field(
description=(
"Emit Dynamo nvext.session_control in OpenAI-compatible request "
"bodies so Dynamo can bind all turns from the same replayed "
"conversation lineage to the same backend worker. This is only "
"intended for Dynamo frontends that implement session_control."
),
),
CLIParameter(
name=(
"--use-dynamo-conv-aware-routing",
"--use-dynamo-session-control",
),
group=Groups.ENDPOINT,
),
] = EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING

# Timeout, in seconds, sent with Dynamo ``nvext.session_control``. The
# ``ge=1`` constraint rejects zero and negative values at validation time.
# The default comes from ``EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS``.
dynamo_session_timeout_seconds: Annotated[
int,
Field(
description=(
"Dynamo nvext.session_control timeout in seconds when "
"--use-dynamo-conv-aware-routing is enabled."
),
ge=1,
),
CLIParameter(
name=("--dynamo-session-timeout-seconds",),
group=Groups.ENDPOINT,
),
] = EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS

connection_reuse_strategy: Annotated[
ConnectionReuseStrategy,
Field(
Expand Down
15 changes: 15 additions & 0 deletions src/aiperf/common/models/model_endpoint_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ class EndpointInfo(AIPerfBaseModel):
default=EndpointDefaults.USE_SERVER_TOKEN_COUNT,
description="Use server-reported token counts from API usage fields instead of client-side tokenization.",
)
# Runtime counterparts of the endpoint-config Dynamo options;
# ``from_user_config`` copies both values over from ``user_config.endpoint``.
# Defaults are shared with the config layer through ``EndpointDefaults``.
use_dynamo_conv_aware_routing: bool = Field(
default=EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING,
description="Emit Dynamo nvext.session_control for conversation-aware routing.",
)
# ``ge=1`` applies the same lower bound as the config-layer field.
dynamo_session_timeout_seconds: int = Field(
default=EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS,
ge=1,
description="Timeout in seconds for Dynamo nvext.session_control sessions.",
)
connection_reuse_strategy: ConnectionReuseStrategy = Field(
default=EndpointDefaults.CONNECTION_REUSE_STRATEGY,
description="Transport connection reuse strategy.",
Expand Down Expand Up @@ -164,6 +173,12 @@ def from_user_config(cls, user_config: UserConfig) -> "EndpointInfo":
api_key=user_config.endpoint.api_key,
use_legacy_max_tokens=user_config.endpoint.use_legacy_max_tokens,
use_server_token_count=user_config.endpoint.use_server_token_count,
use_dynamo_conv_aware_routing=(
user_config.endpoint.use_dynamo_conv_aware_routing
),
dynamo_session_timeout_seconds=(
user_config.endpoint.dynamo_session_timeout_seconds
),
connection_reuse_strategy=user_config.endpoint.connection_reuse_strategy,
download_video_content=user_config.endpoint.download_video_content,
request_content_type=user_config.endpoint.request_content_type,
Expand Down
125 changes: 116 additions & 9 deletions src/aiperf/workers/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
import time
import uuid
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import orjson

Expand Down Expand Up @@ -101,6 +101,22 @@ def _apply_cache_bust_to_system_message(
return system_message


def _content_contains_marker(content: Any, marker: str) -> bool:
"""Return whether a message/text payload already carries this marker."""
marker_text = marker.strip()
markers = [value for value in (marker, marker_text) if value]
if isinstance(content, str):
return any(value in content for value in markers)
if isinstance(content, list):
for part in content:
if not isinstance(part, dict):
continue
text = part.get("text")
if isinstance(text, str) and any(value in text for value in markers):
return True
return False


def _inject_marker_into_raw_messages(
raw_messages: list[dict], marker: str, *, is_prefix: bool
) -> None:
Expand All @@ -117,6 +133,8 @@ def _inject_marker_into_raw_messages(
if not isinstance(first, dict) or first.get("role") != "system":
return
content = first.get("content", "")
if _content_contains_marker(content, marker):
return
if isinstance(content, str):
raw_messages[0] = {
**first,
Expand Down Expand Up @@ -149,6 +167,8 @@ def _inject_marker_into_first_user_turn(
for idx, msg in enumerate(raw_messages):
if isinstance(msg, dict) and msg.get("role") == "user":
content = msg.get("content", "")
if _content_contains_marker(content, marker):
return
if isinstance(content, str):
raw_messages[idx] = {
**msg,
Expand Down Expand Up @@ -231,6 +251,8 @@ def _inject_marker_into_first_user_text(
first.contents = [marker.strip()]
return
existing = first.contents[0]
if _content_contains_marker(existing, marker):
return
first.contents[0] = (marker + existing) if is_prefix else (existing + marker)


Expand Down Expand Up @@ -277,12 +299,7 @@ def _apply_cache_bust(
``system_message`` nor a leading ``role=="system"`` entry in any turn's
``raw_messages``), the marker is routed to the first user turn with the
same prefix/suffix orientation — i.e. SYSTEM_PREFIX falls back to a
first-user-turn prefix, SYSTEM_SUFFIX falls back to a first-user-turn
suffix. Without a system prompt the first user message is the prefix of
the entire wire payload, so this produces the same physical token-0
divergence without fabricating a system role. The fallback is gated on
``credit.turn_index == 0`` (matches FIRST_TURN_* semantics: marker only
affects the first turn's KV cache; later turns inherit).
first-user-turn prefix, SYSTEM_SUFFIX falls back to a first-user-turn suffix.
"""
marker = credit.cache_bust_marker
target = credit.cache_bust_target
Expand Down Expand Up @@ -317,8 +334,7 @@ def _apply_cache_bust(
_inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix)
return system_message

if credit.turn_index == 0:
_inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix)
_inject_marker_at_first_user(session.turn_list, marker, is_prefix=is_prefix)
return system_message


Expand Down Expand Up @@ -460,6 +476,9 @@ def __init__(
# high concurrency — the misconfiguration is the same for every credit.
self._cache_bust_warning_shown: bool = False

# Worker-local bind cache for Dynamo conversation-aware routing.
self._dynamo_bound_session_ids: set[str] = set()

# Only used as a fallback when dataset client is not initialized
# or was not available when the credit was dropped. Must be created here
# so it can be attached to the worker lifecycle.
Expand Down Expand Up @@ -938,6 +957,11 @@ def _create_request_info(
credit = credit_context.credit
if turns is None:
turns = session.turn_list if session else []
turns, payload_bytes = self._add_dynamo_session_control(
credit=credit,
turns=turns,
payload_bytes=payload_bytes,
)
return RequestInfo(
model_endpoint=self.model_endpoint,
credit_num=credit.id,
Expand Down Expand Up @@ -967,6 +991,89 @@ def _create_request_info(
else None,
)

def _add_dynamo_session_control(
    self,
    *,
    credit: Credit,
    turns: list[Turn],
    payload_bytes: bytes | None,
) -> tuple[list[Turn], bytes | None]:
    """Attach Dynamo ``nvext.session_control`` to the outgoing request.

    Does nothing unless ``use_dynamo_conv_aware_routing`` is set on the
    endpoint. When enabled, the session-control object is merged into
    exactly one of the following, checked in this order:

    1. ``payload_bytes`` — decoded, patched and re-encoded.
    2. The last turn's ``raw_payload``, when it is not ``None``.
    3. The last turn's ``extra_body``.

    The input ``turns`` list is never mutated; a patched copy is returned.
    When the emitted control carries ``action == "bind"``, its session id
    is recorded in ``self._dynamo_bound_session_ids``.

    Args:
        credit: Credit used to derive the session-control object.
        turns: Turns for this request; only the last one is patched.
        payload_bytes: Pre-serialized JSON body, or ``None``.

    Returns:
        The ``(turns, payload_bytes)`` pair to use for the request.

    Raises:
        ValueError: If ``payload_bytes`` decodes to something other than a
            JSON object.
    """
    if not self.model_endpoint.endpoint.use_dynamo_conv_aware_routing:
        return turns, payload_bytes

    control = self._dynamo_session_control_for_credit(credit)
    # The merge helper copies ``control`` rather than mutating it, so the
    # bind decision can be computed once up front.
    is_bind = control.get("action") == "bind"

    if payload_bytes is not None:
        decoded = orjson.loads(payload_bytes)
        if not isinstance(decoded, dict):
            raise ValueError("Dynamo session_control requires object payload_bytes")
        patched_payload = self._merge_dynamo_session_control(decoded, control)
        if is_bind:
            self._dynamo_bound_session_ids.add(control["session_id"])
        return turns, orjson.dumps(patched_payload)

    if not turns:
        # Nothing to patch, and the session is deliberately not recorded.
        return turns, payload_bytes

    final_turn = turns[-1]
    if final_turn.raw_payload is not None:
        field_name = "raw_payload"
        base_payload = final_turn.raw_payload
    else:
        field_name = "extra_body"
        base_payload = final_turn.extra_body or {}
    patch: dict[str, Any] = {
        field_name: self._merge_dynamo_session_control(base_payload, control)
    }

    patched_turns = list(turns)
    patched_turns[-1] = final_turn.model_copy(update=patch)
    if is_bind:
        self._dynamo_bound_session_ids.add(control["session_id"])
    return patched_turns, payload_bytes

def _dynamo_session_control_for_credit(self, credit: Credit) -> dict[str, Any]:
    """Build the ``session_control`` object for one credit.

    The result always carries ``session_id`` and ``timeout``. It also
    carries ``action: "bind"`` when the session id is not yet present in
    ``self._dynamo_bound_session_ids``. This method only reads that set.

    Args:
        credit: Credit whose session id is used.

    Returns:
        A new dict holding the session-control fields.
    """
    sid = self._dynamo_session_id(credit)
    timeout = self.model_endpoint.endpoint.dynamo_session_timeout_seconds
    control: dict[str, Any] = {"session_id": sid, "timeout": timeout}
    already_bound = sid in self._dynamo_bound_session_ids
    if not already_bound:
        control["action"] = "bind"
    return control

@staticmethod
def _dynamo_session_id(credit: Credit) -> str:
    """Return the Dynamo session id for ``credit``: its ``x_correlation_id``."""
    session_id = credit.x_correlation_id
    return session_id

@staticmethod
def _merge_dynamo_session_control(
payload: dict[str, Any],
session_control: dict[str, Any],
) -> dict[str, Any]:
merged = dict(payload)
raw_nvext = merged.get("nvext")
nvext = dict(raw_nvext) if isinstance(raw_nvext, dict) else {}
raw_session_control = nvext.get("session_control")
merged_session_control = (
dict(raw_session_control) if isinstance(raw_session_control, dict) else {}
)
merged_session_control.update(session_control)
nvext["session_control"] = merged_session_control
merged["nvext"] = nvext
return merged

async def _retrieve_conversation(
self,
*,
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/common/config/test_endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def test_endpoint_config_defaults():
assert config.custom_endpoint == EndpointDefaults.CUSTOM_ENDPOINT
assert config.streaming == EndpointDefaults.STREAMING
assert config.url == EndpointDefaults.URL
assert (
config.use_dynamo_conv_aware_routing
== EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING
)
assert (
config.dynamo_session_timeout_seconds
== EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS
)


def test_endpoint_config_custom_values():
Expand All @@ -49,6 +57,8 @@ def test_endpoint_config_custom_values():
"urls": ["http://custom-url"],
"timeout_seconds": 10,
"api_key": "custom_api_key",
"use_dynamo_conv_aware_routing": True,
"dynamo_session_timeout_seconds": 123,
}
config = EndpointConfig(**custom_values)
for key, value in custom_values.items():
Expand Down
25 changes: 23 additions & 2 deletions tests/unit/common/models/test_endpoint_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pytest

from aiperf.common.config import EndpointDefaults
from aiperf.common.models.model_endpoint_info import EndpointInfo
from aiperf.common.config import EndpointConfig, EndpointDefaults, UserConfig
from aiperf.common.models.model_endpoint_info import EndpointInfo, ModelEndpointInfo


class TestEndpointInfoMultiURL:
Expand All @@ -16,6 +16,14 @@ def test_single_url_default(self):
info = EndpointInfo()
assert info.base_urls == [EndpointDefaults.URL]
assert info.base_url == EndpointDefaults.URL
assert (
info.use_dynamo_conv_aware_routing
== EndpointDefaults.USE_DYNAMO_CONV_AWARE_ROUTING
)
assert (
info.dynamo_session_timeout_seconds
== EndpointDefaults.DYNAMO_SESSION_TIMEOUT_SECONDS
)

def test_single_url_custom(self):
"""Custom single URL should work."""
Expand All @@ -35,6 +43,19 @@ def test_base_urls_must_have_at_least_one(self):
with pytest.raises(ValueError):
EndpointInfo(base_urls=[])

def test_dynamo_session_control_from_user_config(self):
    """Dynamo options set on ``EndpointConfig`` must reach the runtime endpoint info."""
    endpoint_config = EndpointConfig(
        model_names=["test-model"],
        use_dynamo_conv_aware_routing=True,
        dynamo_session_timeout_seconds=123,
    )
    user_config = UserConfig(endpoint=endpoint_config)

    model_endpoint = ModelEndpointInfo.from_user_config(user_config)
    info = model_endpoint.endpoint

    assert info.use_dynamo_conv_aware_routing is True
    assert info.dynamo_session_timeout_seconds == 123


class TestEndpointInfoGetUrl:
"""Tests for EndpointInfo.get_url() method."""
Expand Down
Loading
Loading