Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/rust-sdk-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,12 @@ jobs:
run: pwsh.exe -Command "Write-Host 'PowerShell ready'"

- name: cargo test
timeout-minutes: 90
env:
RUST_E2E_CONCURRENCY: 4
COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }}
COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }}
run: cargo test --features test-support
run: cargo test --features test-support -- --test-threads=4 --nocapture

# Validates the `embedded-cli` build path on all three supported
# platforms. This is the only place `build.rs` actually runs (the
Expand Down
10 changes: 10 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ parking_lot = "0.12"
regex = "1"
sha2 = { version = "0.10", optional = true }
getrandom = "0.2"
uuid = { version = "1", default-features = false, features = ["v4"] }
zstd = { version = "0.13", optional = true }

[dev-dependencies]
Expand Down
22 changes: 19 additions & 3 deletions rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Instant;

use parking_lot::RwLock;
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{Instrument, debug, error, warn};

use crate::{Error, ProtocolError};
Expand Down Expand Up @@ -184,6 +185,8 @@ pub struct JsonRpcClient {
pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
notification_tx: broadcast::Sender<JsonRpcNotification>,
request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
read_task: Mutex<Option<JoinHandle<()>>>,
write_task: Mutex<Option<JoinHandle<()>>>,
}

impl JsonRpcClient {
Expand All @@ -202,22 +205,24 @@ impl JsonRpcClient {
let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();

let writer_span = tracing::error_span!("jsonrpc_write_loop");
tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
let write_task = tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));

let client = Self {
request_id: AtomicU64::new(1),
write_tx,
pending_requests: Arc::new(RwLock::new(HashMap::new())),
notification_tx,
request_tx,
read_task: Mutex::new(None),
write_task: Mutex::new(Some(write_task)),
};

let pending_requests = client.pending_requests.clone();
let notification_tx_clone = client.notification_tx.clone();
let request_tx_clone = client.request_tx.clone();
let reader_span = tracing::error_span!("jsonrpc_read_loop");

tokio::spawn(
let read_task = tokio::spawn(
async move {
Self::read_loop(
reader,
Expand All @@ -229,10 +234,21 @@ impl JsonRpcClient {
}
.instrument(reader_span),
);
*client.read_task.lock() = Some(read_task);

client
}

pub(crate) fn force_close(&self) {
if let Some(task) = self.read_task.lock().take() {
task.abort();
}
if let Some(task) = self.write_task.lock().take() {
task.abort();
}
self.pending_requests.write().clear();
}

/// Writer-actor task. Owns the `AsyncWrite`, drains the command queue,
/// and writes each frame atomically (header + body + flush) before
/// signaling the ack.
Expand Down
199 changes: 166 additions & 33 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ pub enum SessionError {
/// non-empty.
#[error("invalid SessionFsConfig: {0}")]
InvalidSessionFsConfig(String),

/// The CLI returned a different session ID than the one the SDK registered.
#[error("CLI returned session ID {returned} after SDK registered {requested}")]
SessionIdMismatch {
/// Session ID registered by the SDK before the RPC was sent.
requested: SessionId,
/// Session ID returned by the CLI.
returned: SessionId,
},
}

/// How the SDK communicates with the CLI server.
Expand Down Expand Up @@ -873,6 +882,7 @@ struct ClientInner {
state: parking_lot::Mutex<ConnectionState>,
lifecycle_tx: broadcast::Sender<SessionLifecycleEvent>,
on_list_models: Option<Arc<dyn ListModelsHandler>>,
models_cache: parking_lot::Mutex<Arc<tokio::sync::OnceCell<Vec<Model>>>>,
session_fs_configured: bool,
on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
/// Token sent in the `connect` handshake. Auto-generated when the
Expand Down Expand Up @@ -900,6 +910,24 @@ impl Client {
if let Some(cfg) = &options.session_fs {
validate_session_fs_config(cfg)?;
}
// Auth options only make sense when the SDK spawns the CLI; with an
// external server, the server manages its own auth.
if matches!(options.transport, Transport::External { .. }) {
if options.github_token.is_some() {
return Err(Error::InvalidConfig(
"github_token cannot be used with Transport::External \
(external server manages its own auth)"
.to_string(),
));
}
if options.use_logged_in_user == Some(true) {
return Err(Error::InvalidConfig(
"use_logged_in_user cannot be used with Transport::External \
(external server manages its own auth)"
.to_string(),
));
}
}
// Validate token + transport combination. Stdio cannot use a
// connection token; auto-generate a UUID when the SDK spawns
// its own CLI in TCP mode and no explicit token was set.
Expand Down Expand Up @@ -1138,6 +1166,7 @@ impl Client {
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(256).0,
on_list_models,
models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
session_fs_configured,
on_get_trace_context,
effective_connection_token,
Expand Down Expand Up @@ -1752,10 +1781,17 @@ impl Client {
/// When [`ClientOptions::on_list_models`] is set, returns the handler's
/// result without making a `models.list` RPC. Otherwise queries the CLI.
pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
if let Some(handler) = &self.inner.on_list_models {
return handler.list_models().await;
}
Ok(self.rpc().models().list().await?.models)
let cache = self.inner.models_cache.lock().clone();
let models = cache
.get_or_try_init(|| async {
if let Some(handler) = &self.inner.on_list_models {
handler.list_models().await
} else {
Ok(self.rpc().models().list().await?.models)
}
})
.await?;
Ok(models.clone())
}

/// Invoke [`ClientOptions::on_get_trace_context`] when configured,
Expand Down Expand Up @@ -1828,6 +1864,7 @@ impl Client {

let child = self.inner.child.lock().take();
*self.inner.state.lock() = ConnectionState::Disconnected;
*self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
if let Some(mut child) = child
&& let Err(e) = child.kill().await
{
Expand Down Expand Up @@ -1879,10 +1916,12 @@ impl Client {
{
error!(pid = ?pid, error = %e, "failed to send kill signal");
}
self.inner.rpc.force_close();
// Drop all session channels so any awaiters see a closed channel
// instead of waiting for responses that will never arrive.
self.inner.router.clear();
*self.inner.state.lock() = ConnectionState::Disconnected;
*self.inner.models_cache.lock() = Arc::new(tokio::sync::OnceCell::new());
}

/// Subscribe to lifecycle events.
Expand Down Expand Up @@ -2405,43 +2444,137 @@ mod tests {
policy: None,
supported_reasoning_efforts: Vec::new(),
};
let handler = Arc::new(CountingHandler {
let handler: Arc<dyn ListModelsHandler> = Arc::new(CountingHandler {
calls: Arc::clone(&calls),
models: vec![model.clone()],
});

// We can't call list_models() through Client::start without a CLI, but we
// can exercise the override path by directly constructing a Client whose
// inner has the handler set. This is the same dispatch path as the real
// call; from_streams's None default is replaced via inner construction.
let inner = ClientInner {
child: parking_lot::Mutex::new(None),
rpc: {
let (req_tx, _req_rx) = mpsc::unbounded_channel();
let (notif_tx, _notif_rx) = broadcast::channel(16);
let (read_pipe, _write_pipe) = tokio::io::duplex(64);
let (_unused_read, write_pipe) = tokio::io::duplex(64);
JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
},
cwd: PathBuf::from("."),
request_rx: parking_lot::Mutex::new(None),
notification_tx: broadcast::channel(16).0,
router: router::SessionRouter::new(),
negotiated_protocol_version: OnceLock::new(),
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(16).0,
on_list_models: Some(handler),
session_fs_configured: false,
on_get_trace_context: None,
effective_connection_token: None,
};
let client = Client {
inner: Arc::new(inner),
};
let client = client_with_list_models_handler(handler);

let result = client.list_models().await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].id, "byok-gpt-4");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn list_models_serializes_concurrent_cache_misses() {
use std::sync::atomic::{AtomicUsize, Ordering};

struct SlowCountingHandler {
calls: Arc<AtomicUsize>,
models: Vec<Model>,
}
#[async_trait]
impl ListModelsHandler for SlowCountingHandler {
async fn list_models(&self) -> Result<Vec<Model>, Error> {
self.calls.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
Ok(self.models.clone())
}
}

let calls = Arc::new(AtomicUsize::new(0));
let model = Model {
billing: None,
capabilities: ModelCapabilities {
limits: None,
supports: None,
},
default_reasoning_effort: None,
id: "single-flight-model".into(),
name: "Single Flight Model".into(),
policy: None,
supported_reasoning_efforts: Vec::new(),
};
let handler: Arc<dyn ListModelsHandler> = Arc::new(SlowCountingHandler {
calls: Arc::clone(&calls),
models: vec![model],
});
let client = client_with_list_models_handler(handler);

let (first, second) = tokio::join!(client.list_models(), client.list_models());
assert_eq!(first.unwrap()[0].id, "single-flight-model");
assert_eq!(second.unwrap()[0].id, "single-flight-model");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}

#[tokio::test]
async fn cancelled_create_session_unregisters_pending_session() {
let (client_write, _server_read) = tokio::io::duplex(8192);
let (_server_write, client_read) = tokio::io::duplex(8192);
let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
let handle = tokio::spawn({
let client = client.clone();
async move { client.create_session(SessionConfig::default()).await }
});

wait_for_pending_session_registration(&client).await;
handle.abort();
let _ = handle.await;

assert!(client.inner.router.session_ids().is_empty());
client.force_stop();
}

#[tokio::test]
async fn cancelled_resume_session_unregisters_pending_session() {
let (client_write, _server_read) = tokio::io::duplex(8192);
let (_server_write, client_read) = tokio::io::duplex(8192);
let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
let session_id = SessionId::new("resume-cancel-test");
let handle = tokio::spawn({
let client = client.clone();
async move {
client
.resume_session(ResumeSessionConfig::new(session_id))
.await
}
});

wait_for_pending_session_registration(&client).await;
handle.abort();
let _ = handle.await;

assert!(client.inner.router.session_ids().is_empty());
client.force_stop();
}

fn client_with_list_models_handler(handler: Arc<dyn ListModelsHandler>) -> Client {
Client {
inner: Arc::new(ClientInner {
child: parking_lot::Mutex::new(None),
rpc: {
let (req_tx, _req_rx) = mpsc::unbounded_channel();
let (notif_tx, _notif_rx) = broadcast::channel(16);
let (read_pipe, _write_pipe) = tokio::io::duplex(64);
let (_unused_read, write_pipe) = tokio::io::duplex(64);
JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
},
cwd: PathBuf::from("."),
request_rx: parking_lot::Mutex::new(None),
notification_tx: broadcast::channel(16).0,
router: router::SessionRouter::new(),
negotiated_protocol_version: OnceLock::new(),
state: parking_lot::Mutex::new(ConnectionState::Connected),
lifecycle_tx: broadcast::channel(16).0,
on_list_models: Some(handler),
models_cache: parking_lot::Mutex::new(Arc::new(tokio::sync::OnceCell::new())),
session_fs_configured: false,
on_get_trace_context: None,
effective_connection_token: None,
}),
}
}

async fn wait_for_pending_session_registration(client: &Client) {
let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(1);
while client.inner.router.session_ids().is_empty() {
assert!(
tokio::time::Instant::now() < deadline,
"session was not registered"
);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
Loading
Loading