Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions crates/adaptive/src/agent_context_intercept.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Opt-in request intercept for copying scope-local agent context into LLM requests.

use std::sync::Arc;

use nemo_relay::api::llm::LlmRequest;
use nemo_relay::api::runtime::LlmRequestInterceptFn;
use nemo_relay::codec::request::AnnotatedLlmRequest;
use serde_json::Value as Json;

use crate::config::AgentContextComponentConfig;
use crate::context_helpers::resolve_agent_context;

/// Opt-in LLM request intercept that injects canonical agent context into the request body.
pub struct AgentContextIntercept {
inject_body_path: String,
}

impl AgentContextIntercept {
/// Creates a new agent-context request intercept from component config.
pub fn new(config: AgentContextComponentConfig) -> Self {
Self {
inject_body_path: config.inject_body_path,
}
}

/// Converts this intercept into an [`LlmRequestInterceptFn`] suitable for registration.
pub fn into_request_fn(self) -> LlmRequestInterceptFn {
let inject_body_path = self.inject_body_path;
Arc::new(
move |_name: &str, mut request: LlmRequest, annotated: Option<AnnotatedLlmRequest>| {
if let Some(agent_context) = resolve_agent_context() {
insert_json_path_if_absent(
&mut request.content,
&inject_body_path,
&agent_context,
);
}
Ok((request, annotated))
},
)
}
}

fn insert_json_path_if_absent(root: &mut Json, path: &str, value: &Json) {
let parts = path
.split('.')
.filter(|part| !part.is_empty())
.collect::<Vec<_>>();
insert_json_parts_if_absent(root, &parts, value);
}

fn insert_json_parts_if_absent(root: &mut Json, parts: &[&str], value: &Json) {
let Some((head, tail)) = parts.split_first() else {
return;
};
let Some(object) = root.as_object_mut() else {
return;
};
if tail.is_empty() {
object
.entry((*head).to_string())
.or_insert_with(|| value.clone());
return;
}
let child = object
.entry((*head).to_string())
.or_insert_with(|| Json::Object(serde_json::Map::new()));
insert_json_parts_if_absent(child, tail, value);
}

#[cfg(test)]
#[path = "../tests/unit/agent_context_intercept_tests.rs"]
mod tests;
47 changes: 47 additions & 0 deletions crates/adaptive/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ pub struct AdaptiveConfig {
/// Built-in LLM hint injection settings.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub adaptive_hints: Option<AdaptiveHintsComponentConfig>,
/// Built-in agent context propagation settings.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub agent_context: Option<AgentContextComponentConfig>,
/// Built-in tool scheduling settings.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_parallelism: Option<ToolParallelismComponentConfig>,
Expand All @@ -45,6 +48,7 @@ impl Default for AdaptiveConfig {
state: None,
telemetry: None,
adaptive_hints: None,
agent_context: None,
tool_parallelism: None,
acg: None,
policy: ConfigPolicy::default(),
Expand Down Expand Up @@ -136,6 +140,30 @@ impl Default for AdaptiveHintsComponentConfig {
}
}

/// Typed helper for agent context propagation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentContextComponentConfig {
/// Intercept priority. Lower values run first.
#[serde(default = "default_priority")]
pub priority: i32,
/// Whether later request intercepts should be skipped after this one runs.
#[serde(default)]
pub break_chain: bool,
/// JSON path used when injecting request-body agent context.
#[serde(default = "default_agent_context_path")]
pub inject_body_path: String,
}

impl Default for AgentContextComponentConfig {
fn default() -> Self {
Self {
priority: default_priority(),
break_chain: false,
inject_body_path: default_agent_context_path(),
}
}
}

/// Typed helper for tool parallelism settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParallelismComponentConfig {
Expand Down Expand Up @@ -200,6 +228,10 @@ fn default_adaptive_hints_path() -> String {
"nvext.agent_hints".to_string()
}

fn default_agent_context_path() -> String {
"nvext.agent_context".to_string()
}

fn default_tool_parallelism_mode() -> String {
"observe_only".to_string()
}
Expand Down Expand Up @@ -240,6 +272,13 @@ nemo_relay::editor_config! {
nested: AdaptiveHintsComponentConfig,
default: AdaptiveHintsComponentConfig,
},
agent_context => {
label: "agent_context",
kind: Section,
optional: true,
nested: AgentContextComponentConfig,
default: AgentContextComponentConfig,
},
tool_parallelism => {
label: "tool_parallelism",
kind: Section,
Expand Down Expand Up @@ -297,6 +336,14 @@ nemo_relay::editor_config! {
}
}

nemo_relay::editor_config! {
impl AgentContextComponentConfig {
priority => { label: "priority", kind: Integer },
break_chain => { label: "break_chain", kind: Boolean },
inject_body_path => { label: "inject_body_path", kind: String },
}
}

nemo_relay::editor_config! {
impl ToolParallelismComponentConfig {
priority => { label: "priority", kind: Integer },
Expand Down
30 changes: 30 additions & 0 deletions crates/adaptive/src/context_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//! - [`extract_scope_path`]: collects function names from the scope stack for trie lookup
//! - [`read_manual_latency_sensitivity`]: walks all scopes for manual `latency_sensitive` annotations
//! - [`resolve_agent_id`]: returns the first Agent scope name from the scope stack
//! - [`resolve_agent_context`]: returns the nearest scope-local agent context
//!
//! All functions are safe to call from sync contexts (intercepts are sync closures).
//! They acquire a read lock on the scope stack, which is always fast.
Expand All @@ -20,11 +21,15 @@

use nemo_relay::api::runtime::current_scope_stack;
use nemo_relay::api::scope::ScopeType;
use serde_json::Value as Json;
use uuid::Uuid;

/// Metadata key path for manual latency sensitivity annotation.
pub const LATENCY_SENSITIVITY_POINTER: &str = "/nemo_relay_adaptive/latency_sensitivity";

/// Metadata key path for the canonical agent context object.
pub const AGENT_CONTEXT_POINTER: &str = "/nemo_relay/agent_context";

/// Session-local scope identity used to coordinate warm-first cohorts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SharedParentScopeIdentity {
Expand Down Expand Up @@ -169,6 +174,31 @@ pub fn resolve_agent_id() -> Option<String> {
.map(|s| s.name.clone())
}

/// Resolves the nearest canonical agent context from the current scope stack.
///
/// Producers attach this object to scope metadata at
/// `/nemo_relay/agent_context`. Request intercepts read the nearest active
/// value, so child agent scopes override their parent turn context.
///
/// # Returns
/// A cloned JSON object when one is visible on the current scope stack.
/// Returns `None` when no context exists or the scope stack cannot be read.
pub fn resolve_agent_context() -> Option<Json> {
let stack_handle = current_scope_stack();
let stack = match stack_handle.read() {
Ok(s) => s,
Err(_) => return None,
};
stack.scopes().iter().rev().find_map(|scope| {
scope
.metadata
.as_ref()
.and_then(|metadata| metadata.pointer(AGENT_CONTEXT_POINTER))
.filter(|value| value.is_object())
.cloned()
})
}

/// Resolves the session-local identity used by warm-first cohort coordination.
///
/// The shared parent must come from the parent scope, not the current scope's
Expand Down
25 changes: 21 additions & 4 deletions crates/adaptive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod acg_component;
pub mod acg_learner;
pub mod acg_profile;
pub mod adaptive_hints_intercept;
pub mod agent_context_intercept;
pub mod cache_diagnostics;
pub mod config;
pub mod context_helpers;
Expand All @@ -36,12 +37,13 @@ pub mod trie;
pub mod types;

pub use config::{
AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, BackendSpec, StateConfig,
TelemetryComponentConfig, ToolParallelismComponentConfig,
AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, AgentContextComponentConfig,
BackendSpec, StateConfig, TelemetryComponentConfig, ToolParallelismComponentConfig,
};
pub use context_helpers::{
LATENCY_SENSITIVITY_POINTER, extract_scope_path, read_manual_latency_sensitivity,
resolve_agent_id, resolve_shared_parent_scope_identity, set_latency_sensitivity,
AGENT_CONTEXT_POINTER, LATENCY_SENSITIVITY_POINTER, extract_scope_path,
read_manual_latency_sensitivity, resolve_agent_context, resolve_agent_id,
resolve_shared_parent_scope_identity, set_latency_sensitivity,
};
pub use error::{AdaptiveError, Result};
#[cfg(feature = "redis-backend")]
Expand All @@ -50,3 +52,18 @@ pub use runtime::features::AdaptiveRuntime;
pub use storage::erased::AnyBackend;
pub use storage::memory::InMemoryBackend;
pub use storage::traits::{StorageBackend, StorageBackendDyn};

#[cfg(test)]
pub(crate) mod test_support {
use tokio::sync::{Mutex, MutexGuard};

static GLOBAL_RUNTIME_MUTEX: Mutex<()> = Mutex::const_new(());

pub(crate) async fn lock_global_runtime() -> MutexGuard<'static, ()> {
GLOBAL_RUNTIME_MUTEX.lock().await
}

pub(crate) fn blocking_lock_global_runtime() -> MutexGuard<'static, ()> {
GLOBAL_RUNTIME_MUTEX.blocking_lock()
}
}
11 changes: 11 additions & 0 deletions crates/adaptive/src/plugin_component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ fn validate_adaptive_plugin_config(plugin_config: &Map<String, Json>) -> Vec<Con
"state",
"telemetry",
"adaptive_hints",
"agent_context",
"tool_parallelism",
"acg",
"policy",
Expand Down Expand Up @@ -256,6 +257,16 @@ fn validate_adaptive_plugin_config(plugin_config: &Map<String, Json>) -> Vec<Con
);
}

if let Some(agent_context_json) = plugin_config.get("agent_context").and_then(Json::as_object) {
validate_unknown_fields(
&mut diagnostics,
&config.policy,
Some("agent_context".to_string()),
agent_context_json,
&["priority", "break_chain", "inject_body_path"],
);
}

if let Some(tool_parallelism_json) = plugin_config
.get("tool_parallelism")
.and_then(Json::as_object)
Expand Down
43 changes: 41 additions & 2 deletions crates/adaptive/src/runtime/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ use crate::acg_component::{
};
use crate::acg_learner::AcgLearner;
use crate::adaptive_hints_intercept::AdaptiveHintsIntercept;
use crate::agent_context_intercept::AgentContextIntercept;
use crate::cache_diagnostics::{self, CacheDiagnosticsTracker};
use crate::config::{
AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, TelemetryComponentConfig,
ToolParallelismComponentConfig,
AcgComponentConfig, AdaptiveConfig, AdaptiveHintsComponentConfig, AgentContextComponentConfig,
TelemetryComponentConfig, ToolParallelismComponentConfig,
};
use crate::context_helpers::resolve_agent_id;
use crate::error::{AdaptiveError, Result};
Expand Down Expand Up @@ -455,6 +456,9 @@ impl AdaptiveRuntime {
self.runtime_id,
)));
}
if let Some(config) = self.config.agent_context.clone() {
pending.push(Box::new(AgentContextFeature::new(config, self.runtime_id)));
}
if let Some(config) = self.config.tool_parallelism.clone() {
pending.push(Box::new(ToolParallelismFeature::new(
config,
Expand Down Expand Up @@ -634,6 +638,41 @@ impl AdaptiveFeature for AdaptiveHintsFeature {
}
}

struct AgentContextFeature {
name: String,
priority: i32,
break_chain: bool,
config: AgentContextComponentConfig,
}

impl AgentContextFeature {
fn new(config: AgentContextComponentConfig, runtime_id: Uuid) -> Self {
Self {
name: format!("adaptive_{runtime_id}_agent_context_request"),
priority: config.priority,
break_chain: config.break_chain,
config,
}
}
}

impl AdaptiveFeature for AgentContextFeature {
fn register<'a>(
&'a mut self,
ctx: &'a mut RegistrationContext<'_>,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(async move {
let intercept = AgentContextIntercept::new(self.config.clone());
ctx.register_llm_request_intercept(
&self.name,
self.priority,
self.break_chain,
intercept.into_request_fn(),
)
})
}
}

struct ToolParallelismFeature {
name: String,
priority: i32,
Expand Down
Loading
Loading