Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 33 additions & 17 deletions src/aiperf/timing/branch_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def __init__(
# early when one branch completes before another branch's spawning
# turn has been reached.
self._gated_turn_prereq_keys: dict[tuple[str, int], set[str]] = {}
# (conv_id, gated_turn_idx, prereq_key) -> spawning turn index.
# Snapshot seeding consults this to tell prereqs whose spawning turn
# already fired before t* (their children either appear live in the
# snapshot or completed entirely pre-t*) apart from prereqs whose
# spawning turn will fire during replay.
self._prereq_spawning_turn: dict[tuple[str, int, str], int] = {}
# Defense-in-depth duplicate detection against future loaders that
# bypass ``validate_for_orchestrator_v1``. A given
# ``(branch_id, gated_turn_idx)`` tuple must not appear twice — that
Expand Down Expand Up @@ -294,6 +300,9 @@ def _build_prereq_index(self) -> None:
self._gated_turn_prereq_keys.setdefault(
(conv.conversation_id, gated_idx), set()
).add(prereq_key)
self._prereq_spawning_turn[
(conv.conversation_id, gated_idx, prereq_key)
] = spawning_idx

def get_branch_ids(self, credit) -> list[str]:
"""Look up the completed turn's ``branch_ids`` from metadata.
Expand Down Expand Up @@ -514,7 +523,19 @@ def _ensure_seeded_join(
for prereq_key in self._gated_turn_prereq_keys.get(
(parent_state.conversation_id, gated_idx), set()
):
pending.outstanding[prereq_key] = PrereqState()
state = PrereqState()
spawning_idx = self._prereq_spawning_turn.get(
(parent_state.conversation_id, gated_idx, prereq_key)
)
if spawning_idx is not None and spawning_idx < parent_state.next_turn_index:
# The spawning turn fired before t* and will never replay.
# Children still alive at t* re-register with expected
# counts during this same seeding pass; a branch with no
# live children completed entirely pre-t* and must seed as
# satisfied, or the gate is permanently unsatisfiable and
# the parent lane silently wedges for the whole phase.
state.registered = True
pending.outstanding[prereq_key] = state

if (
parent_state.waiting_on_children
Expand Down Expand Up @@ -762,22 +783,22 @@ async def _spawn_children_and_register_gates(
self.stats.children_errored += 1
self.stats.children_spawned -= 1

# If no children at all landed (all failed), check for gates that
# are now zero-outstanding and dispatch the gated turn immediately
# to avoid hanging the parent.
# If no children at all landed (all failed), pop gates that are now
# zero-outstanding so the parent is not left suspended on a join
# that can never fire via the child-leaf decrement path.
gates_for_parent = self._future_joins.get(parent_corr, {})
drained_gates: list[PendingBranchJoin] = []
for gated_idx, pending in list(gates_for_parent.items()):
# A gate may be vestigial (created this call and immediately
# satisfied) if every child under every prereq rolled back.
if pending.is_satisfied:
# Only fire NOW if the gate is the parent's IMMEDIATE next
# turn. A delayed gate (intervening turns precede it) must not
# dispatch out of order: pop it silently and let the parent
# advance turns normally; when it reaches the (now un-gated)
# turn, the strategy sends it as an ordinary continuation.
if gated_idx == credit.turn_index + 1:
drained_gates.append(pending)
# Pop silently regardless of position. With the gate gone,
# _maybe_suspend_parent returns False and the strategy's
# normal continuation dispatches the (now un-gated) turn as
# an ordinary next turn - exactly once. Dispatching the
# immediate-next gate here as well double-dispatched the
# same turn: intercept still returned False, so the callback
# handler fell through to handle_credit_return ->
# _dispatch_next_turn for the identical turn_index.
self._pop_future_join(parent_corr, gated_idx)
# If no successful children AND no gated turns, release the
# reserved parent state so the parent can drain.
Expand All @@ -799,11 +820,6 @@ async def _spawn_children_and_register_gates(
del self._descendant_counts[parent_corr]
self._notify_drain() # all-children-rolled-back path: no credit return follows

for pending in drained_gates:
# Zero-outstanding gate with no way to fire via child-leaf
# decrement: dispatch immediately (matches Phase 0 hang-fix).
await self._release_blocked_join(pending)

def _ensure_future_join(
self,
credit,
Expand Down
156 changes: 98 additions & 58 deletions src/aiperf/timing/strategies/agentic_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,17 +132,18 @@ def __init__(
# Cache-bust state. WARMUP and PROFILING construct distinct strategy
# instances (PhaseRunner builds a fresh AgenticReplayStrategy per
# phase), while the shared TrajectorySource keeps each sampled lane's
# x_correlation_id stable across the phase boundary. The MARKER text is
# also warmup-coherent: the digest is computed from
# ``(benchmark_id, recycle_pass, trajectory_index, trace_id)`` —
# phase-agnostic — so warmup turn k_i and profile turn k_i+1 get
# the same marker within the continued session. That preserves the
# KV-cache lineage warmup is meant to prime.
# trajectory_index is stable per "lane" (slot in the trajectory list)
# and reused on recycle, so the digest changes only across recycle
# passes for a given trace_id.
self._recycle_pass: dict[str, int] = {}
self._session_marker: dict[str, str | None] = {}
# x_correlation_id stable across the phase boundary AND carries the
# marker ledger across it. A session continuing into PROFILING reuses
# the exact marker minted for it during WARMUP (see
# ``_mint_marker_for_session``), so warmup turn k_i and profile turn
# k_i+1 share the same marker within the continued session - the
# KV-cache lineage warmup is meant to prime is preserved by identity,
# not by replaying mint order. New sessions draw from the shared
# ``recycle_pass`` counter, which never restarts, so a recycled
# session's digest can never collide with a warmed one.
ledger = conversation_source.cache_bust_ledger
self._recycle_pass: dict[str, int] = ledger.recycle_pass
self._session_marker: dict[str, str | None] = ledger.session_marker
self._correlation_to_lane: dict[str, int] = {}
self._cache_bust_target: CacheBustTarget = (
user_config.input.prompt.cache_bust.target
Expand Down Expand Up @@ -176,10 +177,12 @@ async def setup_phase(self) -> None:
TimingManager construction time.

PROFILING: build the FIFO recycle queue with the FULL set of loader
trace_ids (including trajectory ids). Trajectories run live at
PROFILING start (resumed at k_i+1); the pop loop in
``_spawn_from_recycle_or_id`` skips trace_ids whose session is
currently active so we never spawn a duplicate concurrent session.
trace_ids (including trajectory ids), and pre-register every live
trajectory lane in ``_active_traces``. Trajectories run live at
PROFILING start (resumed at k_i+1); pre-registering them here -
rather than lane-by-lane during dispatch - means a lane that
recycles immediately at startup can never pop a trace whose own
lane simply hasn't dispatched yet (a duplicate concurrent session).
"""
if self.config.phase == CreditPhase.PROFILING:
if not self.conversation_source.trajectories:
Expand All @@ -188,6 +191,7 @@ async def setup_phase(self) -> None:
"WARMUP must complete with at least one trajectory before "
"PROFILING can start. Check loader output and warmup failures."
)
self._active_traces.update(self._lanes_per_trace)
self._recycle_queue = asyncio.Queue()
# Recycle pool spans the FULL dataset, not (full - trajectories).
# Trajectories run live at PROFILING start (resumed at k_i+1) and
Expand Down Expand Up @@ -260,43 +264,75 @@ async def _execute_warmup(self) -> None:
async def _execute_profiling(self) -> None:
"""Resume each trajectory at ``k_i + 1`` to seed the steady state.

Subsequent turns and recycle-pool sessions are dispatched from
handle_credit_return.
All trajectories are dispatched concurrently so the full concurrency
target is reached as fast as slot limits allow, rather than
serializing over N credit round-trips. Subsequent turns and
recycle-pool sessions are dispatched from handle_credit_return.
"""
self.info(
f"PROFILING execute: resuming {len(self.conversation_source.trajectories)} "
f"trajectory sessions"
)
for lane, trajectory in enumerate(self.conversation_source.trajectories):
if trajectory.snapshot is not None:
await self._dispatch_snapshot_for_profiling(trajectory, lane)
# return_exceptions=True keeps ownership of every lane until it
# settles: a bare gather would re-raise the first failure while the
# sibling coroutines keep issuing credits into a failing phase,
# unreachable by the phase runner's cancellation.
results = await asyncio.gather(
*(
self._dispatch_one_profiling_trajectory(trajectory, lane)
for lane, trajectory in enumerate(self.conversation_source.trajectories)
),
return_exceptions=True,
)
first_error: BaseException | None = None
for lane, result in enumerate(results):
if not isinstance(result, BaseException):
continue

session = self.conversation_source.session_for(trajectory)
self._correlation_to_lane[session.x_correlation_id] = lane
self._active_traces[trajectory.conversation_id] += 1
self._mint_marker_for_session(
session.x_correlation_id, trajectory.conversation_id, lane
trace_id = self.conversation_source.trajectories[lane].conversation_id
self.error(
f"PROFILING dispatch failed for lane {lane} "
f"(trace_id={trace_id!r}): {result!r}"
)
resume_index = trajectory.start_turn_index + 1
num_turns = len(session.metadata.turns)

if resume_index >= num_turns:
# Trajectory's k_i was already the last turn (rare: happens
# only for very short traces). Skip directly to recycle.
self.debug(
lambda cid=trajectory.conversation_id,
k=trajectory.start_turn_index,
n=num_turns: f"Trajectory {cid} k_i={k} >= last turn (n={n}); recycling immediately"
)
await self._spawn_from_recycle_or_id(
trajectory.conversation_id,
finished_correlation_id=session.x_correlation_id,
if first_error is None:
first_error = result
if first_error is not None:
raise first_error

async def _dispatch_one_profiling_trajectory(
self, trajectory: Trajectory, lane: int
) -> None:
"""Dispatch one lane's initial PROFILING credit (run under gather)."""
if trajectory.snapshot is not None:
await self._dispatch_snapshot_for_profiling(trajectory, lane)
return

session = self.conversation_source.session_for(trajectory)
self._correlation_to_lane[session.x_correlation_id] = lane
# The lane's trace was pre-registered in _active_traces by setup_phase.
self._mint_marker_for_session(
session.x_correlation_id, trajectory.conversation_id, lane
)
resume_index = trajectory.start_turn_index + 1
num_turns = len(session.metadata.turns)

if resume_index >= num_turns:
# Trajectory's k_i was already the last turn (rare: happens
# only for very short traces). Skip directly to recycle.
self.debug(
lambda: (
f"Trajectory {trajectory.conversation_id} "
f"k_i={trajectory.start_turn_index} >= last turn "
f"(n={num_turns}); recycling immediately"
)
continue
)
await self._spawn_from_recycle_or_id(
trajectory.conversation_id,
finished_correlation_id=session.x_correlation_id,
)
return

turn = self._build_turn_for_session(session, resume_index)
await self.credit_issuer.issue_credit(turn)
turn = self._build_turn_for_session(session, resume_index)
await self.credit_issuer.issue_credit(turn)

async def handle_credit_return(
self, credit: Credit, *, error: str | None = None
Expand Down Expand Up @@ -396,8 +432,9 @@ async def _spawn_from_recycle_or_id(

The initial recycle queue spans the full dataset pool (including
trajectory trace_ids whose sessions are running live at PROFILING
start). The pop loop skips trace_ids in ``_active_traces`` and
re-enqueues them to avoid duplicate concurrent sessions.
start; every live lane is pre-registered in ``_active_traces`` by
``setup_phase``). The pop loop skips trace_ids in ``_active_traces``
and re-enqueues them to avoid duplicate concurrent sessions.
"""
# Prune unconditionally so every early-return path leaves dicts clean.
self._session_marker.pop(finished_correlation_id, None)
Expand Down Expand Up @@ -452,10 +489,10 @@ async def _dispatch_snapshot_for_profiling(
) -> None:
warmup_snapshot = self._get_snapshot(trajectory)
snapshot = self._snapshot_continuation_after_warmup(trajectory)
# Each lane's single root session (continuing or terminal) was
# pre-registered in _active_traces by setup_phase.
for state in snapshot.states:
self._correlation_to_lane[state.x_correlation_id] = lane
if state.agent_depth == 0:
self._active_traces[state.conversation_id] += 1
self._mint_marker_for_session(
state.x_correlation_id, state.conversation_id, lane
)
Expand All @@ -472,7 +509,6 @@ async def _dispatch_snapshot_for_profiling(
]
for state in terminal_roots:
self._correlation_to_lane[state.x_correlation_id] = lane
self._active_traces[state.conversation_id] += 1
self._mint_marker_for_session(
state.x_correlation_id, state.conversation_id, lane
)
Expand Down Expand Up @@ -647,22 +683,26 @@ def _build_turn_for_session(
def _mint_marker_for_session(
self, x_correlation_id: str, trace_id: str, trajectory_index: int
) -> str | None:
"""Mint and store a per-session cache-bust marker.
"""Mint (or reuse) and store a per-session cache-bust marker.

Returns None when the feature is disabled (target=NONE), in which
case the session map records None so callers can unconditionally
look it up. Increments _recycle_pass[trace_id] each time a new
session is minted for the same trace_id, so digest rotates across
recycles within a single phase.

The strategy is constructed FRESH for each phase (per the
TimingStrategyProtocol contract; PhaseRunner builds a new instance for
WARMUP and another for PROFILING). Both phases start with empty
``_recycle_pass``, so the first mint for a given trace_id in PROFILING
produces ``pass=0`` — matching WARMUP's pass=0 digest for the same
(trace_id, lane) pair. The shared ``TrajectorySource`` also preserves
the lane's x_correlation_id across the phase boundary.
recycles.

Both ``_session_marker`` and ``_recycle_pass`` live on the shared
``TrajectorySource`` ledger, surviving the WARMUP -> PROFILING
boundary (strategies are constructed fresh per phase). A session
whose x_correlation_id was already minted - a continuing lane
resuming at k_i+1 - keeps its WARMUP marker verbatim instead of
re-minting, so a continued session's digest can never rotate at the
phase boundary regardless of mint order. The pass counter never
restarts, so fresh sessions (recycles, parents unblocked after t*)
can never collide with a warmed digest.
"""
if x_correlation_id in self._session_marker:
return self._session_marker[x_correlation_id]
if self._cache_bust_target == CacheBustTarget.NONE:
self._session_marker[x_correlation_id] = None
return None
Expand Down
31 changes: 31 additions & 0 deletions src/aiperf/timing/trajectory_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,24 @@ class Trajectory:
)


@dataclass(slots=True)
class CacheBustLedger:
"""Cross-phase cache-bust marker state.

Lives on the shared ``TrajectorySource`` (constructed once at
TimingManager level) so the WARMUP and PROFILING strategy instances see
one ledger. A session that continues across the phase boundary keeps the
exact marker minted for it during WARMUP, and new sessions draw pass
numbers from a counter that never restarts - so a recycled session's
digest can never collide with a warmed one.
"""

recycle_pass: dict[str, int] = field(default_factory=dict)
"""Next-pass counter per trace_id; incremented on every fresh mint."""
session_marker: dict[str, str | None] = field(default_factory=dict)
"""Minted marker per live x_correlation_id (None when cache-bust is off)."""


@dataclass(slots=True, frozen=True)
class _BranchRuntime:
branch_id: str
Expand Down Expand Up @@ -189,6 +207,19 @@ def __init__(

self._log_trajectory_summary()

@property
def cache_bust_ledger(self) -> CacheBustLedger:
"""Marker ledger shared by the WARMUP and PROFILING strategy instances.

Created lazily so sources built through ``__new__`` in tests get a
ledger on first access without extra setup.
"""
ledger = getattr(self, "_cache_bust_ledger", None)
if ledger is None:
ledger = CacheBustLedger()
self._cache_bust_ledger = ledger
return ledger

def _log_trajectory_summary(self) -> None:
"""Log a one-block table of every trajectory's start position.

Expand Down
6 changes: 4 additions & 2 deletions tests/component_integration/test_agentic_replay_cli_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,10 @@ def test_agentic_replay_cli_scenario_unsafe_override_runs_to_completion(
"validator must log timing_mode auto-set under --scenario "
"(covers the read-only-property setter path against real UserConfig)"
)
assert "auto-set --inter-turn-delay-cap-seconds=60.0" in log_text, (
"validator must auto-set inter-turn-delay-cap when unset"
assert "auto-set --trace-idle-gap-cap-seconds=60.0" in log_text, (
"validator must auto-set the per-trace idle-gap cap when unset "
"(the AgentX scenario locks trace_idle_gap_cap_seconds, not the "
"inter-turn delay cap, since 932b4bc)"
)

assert result.json is not None, "JSON export must exist"
Expand Down
Loading
Loading