From 2b0b483c606a7dabe3997a23d19ab3cd705bc112 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 12:32:10 -0400 Subject: [PATCH 1/8] feat: expose prediction SSE streams --- crates/coglet/src/orchestrator.rs | 6 +- crates/coglet/src/prediction.rs | 232 ++++++++++++++- crates/coglet/src/service.rs | 274 +++++++++++++++++- crates/coglet/src/transport/http/routes.rs | 149 +++++++++- .../tests/sse_streaming_output.txtar | 36 +++ 5 files changed, 680 insertions(+), 17 deletions(-) create mode 100644 integration-tests/tests/sse_streaming_output.txtar diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index c0aec8b9ed..5694338e35 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -966,7 +966,7 @@ async fn run_event_loop( Ok(SlotResponse::LogLine { source, data }) => { let (prediction_id, poisoned) = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_log(&data); + p.append_log_source(source, &data); (Some(p.id().to_string()), false) } else { (None, true) @@ -1015,10 +1015,10 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::OutputChunk { output, index: _ }) => { + Ok(SlotResponse::OutputChunk { output, index }) => { let poisoned = if let Some(pred) = predictions.get(&slot_id) { if let Some(mut p) = try_lock_prediction(pred) { - p.append_output(output); + p.append_output_chunk(output, index); false } else { true diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 81ab018b40..1e86a20754 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -7,7 +7,7 @@ use std::time::Instant; use tokio::sync::Notify; pub use tokio_util::sync::CancellationToken; -use crate::bridge::protocol::MetricMode; +use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -64,6 +64,70 @@ impl PredictionOutput { } } +#[derive(Debug, Clone)] +pub enum PredictionStreamEvent { + Start { + id: String, + status: String, + }, + Output { + chunk: serde_json::Value, + index: u64, + }, + Log { + source: LogSource, + data: String, + }, + Metric { + name: String, + value: serde_json::Value, + mode: MetricMode, + }, + Completed { + payload: serde_json::Value, + }, +} + +pub struct PredictionStreamReplay { + pub replay: Vec, + pub receiver: tokio::sync::broadcast::Receiver, +} + +impl PredictionStreamEvent { + pub fn event_name(&self) -> &'static str { + match self { + Self::Start { .. } => "start", + Self::Output { .. } => "output", + Self::Log { .. } => "log", + Self::Metric { .. } => "metric", + Self::Completed { .. } => "completed", + } + } + + pub fn json_data(&self) -> serde_json::Value { + match self { + Self::Start { id, status } => serde_json::json!({ + "id": id, + "status": status, + }), + Self::Output { chunk, index } => serde_json::json!({ + "chunk": chunk, + "index": index, + }), + Self::Log { source, data } => serde_json::json!({ + "source": source, + "data": data, + }), + Self::Metric { name, value, mode } => serde_json::json!({ + "name": name, + "value": value, + "mode": mode, + }), + Self::Completed { payload } => payload.clone(), + } + } +} + /// Prediction lifecycle state. pub struct Prediction { id: String, @@ -76,12 +140,16 @@ pub struct Prediction { error: Option, webhook: Option, completion: Arc, + stream_tx: tokio::sync::broadcast::Sender, + stream_history: Vec, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } impl Prediction { pub fn new(id: String, webhook: Option) -> Self { + let (stream_tx, _) = tokio::sync::broadcast::channel(1024); + Self { id, cancel_token: CancellationToken::new(), @@ -93,6 +161,8 @@ impl Prediction { error: None, webhook, completion: Arc::new(Notify::new()), + stream_tx, + stream_history: Vec::new(), metrics: HashMap::new(), } } @@ -105,6 +175,26 @@ impl Prediction { self.cancel_token.clone() } + pub fn subscribe_stream(&self) -> tokio::sync::broadcast::Receiver { + self.stream_tx.subscribe() + } + + pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { + PredictionStreamReplay { + replay: self.stream_history.clone(), + receiver: self.stream_tx.subscribe(), + } + } + + pub fn stream_receiver_count(&self) -> usize { + self.stream_tx.receiver_count() + } + + fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + self.stream_history.push(event.clone()); + let _ = self.stream_tx.send(event); + } + pub fn is_canceled(&self) -> bool { self.cancel_token.is_cancelled() } @@ -119,6 +209,10 @@ impl Prediction { pub fn set_processing(&mut self) { self.status = PredictionStatus::Processing; + self.emit_stream_event(PredictionStreamEvent::Start { + id: self.id.clone(), + status: self.status.as_str().to_string(), + }); self.fire_webhook(WebhookEventType::Start); } @@ -128,6 +222,9 @@ impl Prediction { } self.status = PredictionStatus::Succeeded; self.output = Some(output); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); // notify_one stores a permit so a future .notified().await will // consume it immediately. notify_waiters only wakes currently- @@ -144,6 +241,9 @@ impl Prediction { } self.status = PredictionStatus::Failed; self.error = Some(error); + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -153,6 +253,9 @@ impl Prediction { return; } self.status = PredictionStatus::Canceled; + self.emit_stream_event(PredictionStreamEvent::Completed { + payload: self.build_state_snapshot(), + }); self.fire_terminal_webhook(); self.completion.notify_one(); } @@ -162,7 +265,15 @@ impl Prediction { } pub fn append_log(&mut self, data: &str) { + self.append_log_source(LogSource::Stdout, data); + } + + pub fn append_log_source(&mut self, source: LogSource, data: &str) { self.logs.push_str(data); + self.emit_stream_event(PredictionStreamEvent::Log { + source, + data: data.to_string(), + }); self.fire_webhook(WebhookEventType::Logs); } @@ -184,6 +295,12 @@ impl Prediction { return; } + self.emit_stream_event(PredictionStreamEvent::Metric { + name: name.clone(), + value: value.clone(), + mode, + }); + // Dot-path resolution: "a.b.c" → nested objects let parts: Vec<&str> = name.split('.').collect(); if parts.len() > 1 { @@ -297,7 +414,16 @@ impl Prediction { } pub fn append_output(&mut self, output: serde_json::Value) { - self.outputs.push(output); + let index = self.outputs.len() as u64; + self.append_output_chunk(output, index); + } + + pub fn append_output_chunk(&mut self, output: serde_json::Value, index: u64) { + self.outputs.push(output.clone()); + self.emit_stream_event(PredictionStreamEvent::Output { + chunk: output, + index, + }); self.fire_webhook(WebhookEventType::Output); } @@ -484,6 +610,108 @@ mod tests { assert_eq!(pred.outputs().len(), 2); } + #[tokio::test] + async fn prediction_stream_emits_start_output_log_and_completed() { + let mut prediction = Prediction::new("pred_stream".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.append_log("loading\n"); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let start = rx.recv().await.unwrap(); + assert_eq!(start.event_name(), "start"); + assert_eq!( + start.json_data(), + serde_json::json!({"id":"pred_stream","status":"processing"}) + ); + + let output = rx.recv().await.unwrap(); + assert_eq!(output.event_name(), "output"); + assert_eq!( + output.json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + + let log = rx.recv().await.unwrap(); + assert_eq!(log.event_name(), "log"); + assert_eq!( + log.json_data(), + serde_json::json!({"source":"stdout","data":"loading\n"}) + ); + + let completed = rx.recv().await.unwrap(); + assert_eq!(completed.event_name(), "completed"); + assert_eq!(completed.json_data()["id"], "pred_stream"); + assert_eq!(completed.json_data()["status"], "succeeded"); + assert_eq!( + completed.json_data()["output"], + serde_json::json!(["hello"]) + ); + } + + #[tokio::test] + async fn prediction_stream_emits_metric_event() { + let mut prediction = Prediction::new("pred_metric".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.set_metric( + "tokens".to_string(), + serde_json::json!(1), + MetricMode::Increment, + ); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "metric"); + assert_eq!( + event.json_data(), + serde_json::json!({ + "name":"tokens", + "value":1, + "mode":"increment" + }) + ); + } + + #[tokio::test] + async fn prediction_stream_preserves_log_source() { + let mut prediction = Prediction::new("pred_log_source".to_string(), None); + let mut rx = prediction.subscribe_stream(); + + prediction.append_log_source(crate::bridge::protocol::LogSource::Stderr, "warning\n"); + + let event = rx.recv().await.unwrap(); + assert_eq!(event.event_name(), "log"); + assert_eq!( + event.json_data(), + serde_json::json!({"source":"stderr","data":"warning\n"}) + ); + } + + #[tokio::test] + async fn prediction_stream_replay_includes_already_emitted_events() { + let mut prediction = Prediction::new("pred_replay".to_string(), None); + + prediction.set_processing(); + prediction.append_output_chunk(serde_json::json!("hello"), 0); + prediction.set_succeeded(PredictionOutput::Stream(vec![serde_json::json!("hello")])); + + let replay = prediction.subscribe_stream_replay(); + let events: Vec<&str> = replay + .replay + .iter() + .map(|event| event.event_name()) + .collect(); + + assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!( + replay.replay[1].json_data(), + serde_json::json!({"chunk":"hello","index":0}) + ); + assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 300ccab9e7..80a42676bf 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -79,6 +79,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, + respond_async: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -106,6 +107,49 @@ impl PredictionHandle { } } +pub struct PredictionStreamSubscription { + id: String, + replay: Vec, + receiver: tokio::sync::broadcast::Receiver, + guard: PredictionStreamGuard, +} + +impl PredictionStreamSubscription { + pub fn prediction_id(&self) -> &str { + &self.id + } + + pub fn into_parts( + self, + ) -> ( + Vec, + tokio::sync::broadcast::Receiver, + PredictionStreamGuard, + ) { + (self.replay, self.receiver, self.guard) + } +} + +pub struct PredictionStreamGuard { + id: String, + service: Arc, + respond_async: bool, +} + +impl Drop for PredictionStreamGuard { + fn drop(&mut self) { + if self.respond_async { + return; + } + + if self.service.stream_receiver_count(&self.id) == 0 + && !self.service.prediction_is_terminal(&self.id) + { + self.service.cancel(&self.id); + } + } +} + /// Guard for sync predictions - cancels on drop unless disarmed. /// /// When the HTTP connection drops (client disconnect), axum drops the @@ -415,6 +459,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, + respond_async: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -442,6 +487,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, + respond_async, }, ); @@ -469,6 +515,45 @@ impl PredictionService { Some(response) } + pub fn subscribe_prediction_stream( + self: &Arc, + id: &str, + ) -> Option { + let entry = self.predictions.get(id)?; + let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); + let respond_async = entry.respond_async; + Some(PredictionStreamSubscription { + id: id.to_string(), + replay: stream.replay, + receiver: stream.receiver, + guard: PredictionStreamGuard { + id: id.to_string(), + service: Arc::clone(self), + respond_async, + }, + }) + } + + fn stream_receiver_count(&self, id: &str) -> usize { + self.predictions + .get(id) + .and_then(|entry| { + entry + .prediction + .lock() + .ok() + .map(|p| p.stream_receiver_count()) + }) + .unwrap_or(0) + } + + fn prediction_is_terminal(&self, id: &str) -> bool { + self.predictions + .get(id) + .and_then(|entry| entry.prediction.lock().ok().map(|p| p.is_terminal())) + .unwrap_or(true) + } + /// Run a prediction to completion via orchestrator. pub async fn predict( &self, @@ -760,6 +845,51 @@ mod tests { } } + struct CountingCancelOrchestrator { + cancel_count: AtomicUsize, + } + + impl CountingCancelOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CountingCancelOrchestrator { + async fn register_prediction( + &self, + _slot_id: SlotId, + _prediction: Arc>, + _idle_sender: tokio::sync::oneshot::Sender, + ) { + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -863,7 +993,7 @@ mod tests { let svc = PredictionService::new_no_pool(); let result = svc - .submit_prediction("test".to_string(), serde_json::json!({}), None) + .submit_prediction("test".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::NotReady))); } @@ -908,7 +1038,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -916,6 +1046,113 @@ mod tests { assert!(svc.prediction_exists("test-1")); } + #[tokio::test] + async fn subscribe_prediction_stream_returns_receiver_for_existing_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::new()); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction("stream-test".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("stream-test").unwrap(); + assert_eq!(subscription.prediction_id(), "stream-test"); + } + + #[tokio::test] + async fn dropping_only_sync_stream_subscription_cancels_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "sync-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "async-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("async-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + + #[tokio::test] + async fn dropping_completed_sync_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (_handle, _slot) = svc + .submit_prediction( + "completed-sync-stream".to_string(), + serde_json::json!({}), + None, + false, + ) + .await + .unwrap(); + + { + let entry = svc.predictions.get("completed-sync-stream").unwrap(); + let mut prediction = entry.prediction.lock().unwrap(); + prediction.set_succeeded(crate::PredictionOutput::Single(serde_json::json!("done"))); + } + + let subscription = svc + .subscribe_prediction_stream("completed-sync-stream") + .unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 0); + } + #[tokio::test] async fn submit_returns_at_capacity_when_no_slots() { let svc = PredictionService::new_no_pool(); @@ -927,13 +1164,13 @@ mod tests { // First prediction takes the only slot let (_handle1, _slot1) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); // Second should fail with AtCapacity let result = svc - .submit_prediction("test-2".to_string(), serde_json::json!({}), None) + .submit_prediction("test-2".to_string(), serde_json::json!({}), None, false) .await; assert!(matches!(result, Err(CreatePredictionError::AtCapacity))); } @@ -953,6 +1190,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -986,7 +1224,7 @@ mod tests { // After acquiring slot let (_handle, _slot) = svc - .submit_prediction("test-1".to_string(), serde_json::json!({}), None) + .submit_prediction("test-1".to_string(), serde_json::json!({}), None, false) .await .unwrap(); let health = svc.health().await; @@ -1009,6 +1247,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1048,6 +1287,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1087,6 +1327,7 @@ mod tests { "test-1".to_string(), serde_json::json!({"prompt": "hello"}), None, + false, ) .await .unwrap(); @@ -1113,7 +1354,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-cancel".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-cancel".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1139,7 +1385,7 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-guard".to_string(), serde_json::json!({}), None) + .submit_prediction("test-guard".to_string(), serde_json::json!({}), None, false) .await .unwrap(); @@ -1163,7 +1409,12 @@ mod tests { svc.set_health(Health::Ready).await; let (handle, _slot) = svc - .submit_prediction("test-disarm".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-disarm".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); @@ -1187,7 +1438,12 @@ mod tests { svc.set_health(Health::Ready).await; let (_handle, _slot) = svc - .submit_prediction("test-remove".to_string(), serde_json::json!({}), None) + .submit_prediction( + "test-remove".to_string(), + serde_json::json!({}), + None, + false, + ) .await .unwrap(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index aa9875340e..fd1b7cdc8f 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -1,12 +1,17 @@ //! HTTP route handlers. +use std::convert::Infallible; use std::sync::Arc; +use std::time::Duration; use axum::{ Router, extract::{DefaultBodyLimit, Path, State}, http::{HeaderMap, StatusCode}, - response::{IntoResponse, Json}, + response::{ + IntoResponse, Json, Response, + sse::{Event, KeepAlive, Sse}, + }, routing::{get, post, put}, }; use serde::{Deserialize, Serialize}; @@ -15,7 +20,9 @@ use serde::{Deserialize, Serialize}; use crate::health::Health; use crate::health::{HealthResponse, SetupResult}; use crate::predictor::PredictionError; -use crate::service::{CreatePredictionError, HealthSnapshot, PredictionService}; +use crate::service::{ + CreatePredictionError, HealthSnapshot, PredictionService, PredictionStreamSubscription, +}; use crate::version::VersionInfo; use crate::webhook::{TraceContext, WebhookConfig, WebhookEventType, WebhookSender}; @@ -376,7 +383,12 @@ async fn create_prediction_with_id( // Submit prediction: creates Prediction, acquires slot, registers in service let (handle, unregistered_slot) = match service - .submit_prediction(prediction_id.clone(), input.clone(), webhook_sender) + .submit_prediction( + prediction_id.clone(), + input.clone(), + webhook_sender, + respond_async, + ) .await { Ok(r) => r, @@ -557,6 +569,80 @@ async fn cancel_prediction( } } +fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event { + Event::default() + .event(event.event_name()) + .json_data(event.json_data()) + .expect("prediction stream events serialize to JSON") +} + +fn prediction_sse_stream( + subscription: PredictionStreamSubscription, +) -> impl futures::Stream> { + let (replay, receiver, guard) = subscription.into_parts(); + + struct StreamState { + replay: std::collections::VecDeque, + receiver: tokio::sync::broadcast::Receiver, + _guard: crate::service::PredictionStreamGuard, + done: bool, + } + + futures::stream::unfold( + StreamState { + replay: replay.into(), + receiver, + _guard: guard, + done: false, + }, + |mut state| async move { + if state.done { + return None; + } + + if let Some(event) = state.replay.pop_front() { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + + loop { + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + return Some((Ok(stream_event_to_sse(event)), state)); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + continue; + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => return None, + } + } + }, + ) +} + +async fn stream_prediction( + State(service): State>, + Path(prediction_id): Path, +) -> Response { + let Some(subscription) = service.subscribe_prediction_stream(&prediction_id) else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + + Sse::new(prediction_sse_stream(subscription)) + .keep_alive( + KeepAlive::new() + .interval(Duration::from_secs(15)) + .text("keep-alive"), + ) + .into_response() +} + async fn shutdown(State(service): State>) -> impl IntoResponse { tracing::info!("Shutdown requested via HTTP"); service.trigger_shutdown(); @@ -679,6 +765,7 @@ pub fn routes(service: Arc) -> Router { .route("/shutdown", post(shutdown)) .route("/predictions", post(create_prediction)) .route("/predictions/{id}", put(create_prediction_idempotent)) + .route("/predictions/{id}/stream", get(stream_prediction)) .route("/predictions/{id}/cancel", post(cancel_prediction)) .route("/trainings", post(create_training)) .route("/trainings/{id}", put(create_training_idempotent)) @@ -955,6 +1042,62 @@ mod tests { assert_eq!(json["status"], "starting"); } + #[tokio::test] + async fn stream_prediction_unknown_id_returns_404() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/missing/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + let json = response_json(response).await; + assert_eq!(json["error"], "Prediction not found"); + } + + #[tokio::test] + async fn stream_prediction_existing_id_returns_sse() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar new file mode 100644 index 0000000000..76e7d4d47d --- /dev/null +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -0,0 +1,36 @@ +# Test that async generator output is available over the SSE stream endpoint. + +[short] skip 'requires Docker build' + +cog build -t $TEST_IMAGE +cog serve --upload-url http://unused/ + +curl -H Prefer:respond-async PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout '"status":"starting"' + +curl -N -H Accept:text/event-stream GET /predictions/sse-stream-test/stream +stdout 'event: output' +stdout 'data: {"chunk":"chunk-1","index":0}' +stdout 'event: output' +stdout 'data: {"chunk":"chunk-2","index":1}' +stdout 'event: completed' +stdout '"status":"succeeded"' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +import time +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[str]: + time.sleep(0.25) + yield "chunk-1" + time.sleep(0.25) + yield "chunk-2" From c61e323d0b6a24e89962a4e1c267c4b5a79f0a8e Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 13:00:14 -0400 Subject: [PATCH 2/8] docs: add streaming text example --- examples/streaming-text/.dockerignore | 3 ++ examples/streaming-text/.gitignore | 3 ++ examples/streaming-text/README.md | 58 +++++++++++++++++++++ examples/streaming-text/cog.yaml | 7 +++ examples/streaming-text/predict.py | 64 ++++++++++++++++++++++++ examples/streaming-text/requirements.txt | 3 ++ 6 files changed, 138 insertions(+) create mode 100644 examples/streaming-text/.dockerignore create mode 100644 examples/streaming-text/.gitignore create mode 100644 examples/streaming-text/README.md create mode 100644 examples/streaming-text/cog.yaml create mode 100644 examples/streaming-text/predict.py create mode 100644 examples/streaming-text/requirements.txt diff --git a/examples/streaming-text/.dockerignore b/examples/streaming-text/.dockerignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.dockerignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/.gitignore b/examples/streaming-text/.gitignore new file mode 100644 index 0000000000..9118d0e055 --- /dev/null +++ b/examples/streaming-text/.gitignore @@ -0,0 +1,3 @@ +.cog/ +__pycache__/ +*.pyc diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md new file mode 100644 index 0000000000..947d5b6625 --- /dev/null +++ b/examples/streaming-text/README.md @@ -0,0 +1,58 @@ +# examples/streaming-text + +Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. + +This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks from the Server-Sent Events stream endpoint. + +## Run a normal prediction + +From this directory: + +```sh +cog predict -i prompt="Write a short haiku about databases" +``` + +This returns the final accumulated output after the prediction completes. + +## Stream output over HTTP + +Start the server: + +```sh +cog serve +``` + +Create an async prediction with a fixed ID: + +```sh +curl -s -X PUT http://localhost:5000/predictions/streaming-demo \ + -H 'Content-Type: application/json' \ + -H 'Prefer: respond-async' \ + -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' +``` + +Then subscribe to its stream: + +```sh +curl -N -H 'Accept: text/event-stream' \ + http://localhost:5000/predictions/streaming-demo/stream +``` + +The response includes `output` events as chunks are generated, followed by a `completed` event: + +```text +event: output +data: {"chunk":"Silent","index":0} + +event: output +data: {"chunk":" rows","index":1} + +event: completed +data: {"id":"streaming-demo","status":"succeeded",...} +``` + +## How it works + +`predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. + +The normal prediction response still contains the accumulated output for compatibility. The stream endpoint is useful when clients want to display tokens as they arrive. diff --git a/examples/streaming-text/cog.yaml b/examples/streaming-text/cog.yaml new file mode 100644 index 0000000000..68866b7aeb --- /dev/null +++ b/examples/streaming-text/cog.yaml @@ -0,0 +1,7 @@ +# Streaming text generation example using a small open-weight language model. + +build: + python_version: "3.12" + python_requirements: requirements.txt + +predict: "predict.py:Predictor" diff --git a/examples/streaming-text/predict.py b/examples/streaming-text/predict.py new file mode 100644 index 0000000000..3ed8772e54 --- /dev/null +++ b/examples/streaming-text/predict.py @@ -0,0 +1,64 @@ +from threading import Thread +from typing import Iterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +from cog import BasePredictor, Input + +MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct" + + +class Predictor(BasePredictor): + def setup(self) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + dtype = torch.float16 if self.device == "cuda" else torch.float32 + + self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + self.model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, + torch_dtype=dtype, + ).to(self.device) + self.model.eval() + + def predict( + self, + prompt: str = Input(description="Prompt to complete"), + max_new_tokens: int = Input( + description="Maximum number of tokens to generate", + default=128, + ge=1, + le=512, + ), + ) -> Iterator[str]: + messages = [{"role": "user", "content": prompt}] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = self.tokenizer([text], return_tensors="pt").to(self.device) + streamer = TextIteratorStreamer( + self.tokenizer, + skip_prompt=True, + skip_special_tokens=True, + ) + + generation_kwargs = { + **inputs, + "streamer": streamer, + "max_new_tokens": max_new_tokens, + "do_sample": True, + "temperature": 0.7, + "top_p": 0.9, + "pad_token_id": self.tokenizer.eos_token_id, + } + + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + + for chunk in streamer: + if chunk: + yield chunk + + thread.join() diff --git a/examples/streaming-text/requirements.txt b/examples/streaming-text/requirements.txt new file mode 100644 index 0000000000..916933a3de --- /dev/null +++ b/examples/streaming-text/requirements.txt @@ -0,0 +1,3 @@ +torch==2.7.1 +transformers==4.51.3 +accelerate==1.6.0 From bf064d22475bd6b0f7a3cbf0008f8b372f4f48eb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 15:58:26 -0400 Subject: [PATCH 3/8] feat: stream predictions via accept header --- crates/coglet/src/transport/http/routes.rs | 196 ++++++++++++++---- examples/streaming-text/README.md | 17 +- .../tests/sse_streaming_output.txtar | 7 +- 3 files changed, 163 insertions(+), 57 deletions(-) diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index fd1b7cdc8f..1cdc6f6c18 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -216,6 +216,43 @@ fn should_respond_async(headers: &HeaderMap) -> bool { .unwrap_or(false) } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PredictionResponseMode { + SyncJson, + AsyncJson, + AsyncSse, +} + +fn wants_sse(headers: &HeaderMap) -> bool { + headers + .get(axum::http::header::ACCEPT) + .and_then(|value| value.to_str().ok()) + .map(|accept| { + accept + .split(',') + .any(|part| part.trim().split(';').next() == Some("text/event-stream")) + }) + .unwrap_or(false) +} + +fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if wants_sse(headers) { + PredictionResponseMode::AsyncSse + } else if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + +fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { + if should_respond_async(headers) { + PredictionResponseMode::AsyncJson + } else { + PredictionResponseMode::SyncJson + } +} + fn extract_trace_context(headers: &HeaderMap) -> TraceContext { TraceContext { traceparent: headers @@ -233,7 +270,7 @@ async fn create_prediction( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -242,7 +279,7 @@ async fn create_prediction( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = prediction_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -251,7 +288,7 @@ async fn create_prediction( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -263,7 +300,7 @@ async fn create_prediction_idempotent( Path(prediction_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -284,15 +321,20 @@ async fn create_prediction_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } + let response_mode = prediction_response_mode(&headers); + // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { - return (StatusCode::ACCEPTED, Json(response)); + if response_mode == PredictionResponseMode::AsyncSse { + return stream_prediction_response(service, &prediction_id); + } + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -301,7 +343,7 @@ async fn create_prediction_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, false, ) @@ -340,10 +382,10 @@ async fn create_prediction_with_id( context: std::collections::HashMap, webhook: Option, webhook_events_filter: Vec, - respond_async: bool, + response_mode: PredictionResponseMode, trace_context: TraceContext, is_training: bool, -) -> (StatusCode, Json) { +) -> Response { // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -372,7 +414,8 @@ async fn create_prediction_with_id( return ( StatusCode::UNPROCESSABLE_ENTITY, Json(serde_json::json!({ "detail": detail })), - ); + ) + .into_response(); } let webhook_sender = build_webhook_sender( @@ -387,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - respond_async, + response_mode != PredictionResponseMode::SyncJson, ) .await { @@ -400,7 +443,8 @@ async fn create_prediction_with_id( "error": msg, "status": "failed" })), - ); + ) + .into_response(); } Err(CreatePredictionError::AtCapacity) => { return ( @@ -409,14 +453,15 @@ async fn create_prediction_with_id( "error": "At capacity - all prediction slots busy", "status": "failed" })), - ); + ) + .into_response(); } }; let prediction = unregistered_slot.prediction(); // Async mode: spawn background task, return immediately - if respond_async { + if response_mode != PredictionResponseMode::SyncJson { let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -429,13 +474,18 @@ async fn create_prediction_with_id( service_clone.remove_prediction(&id_for_cleanup); }); + if response_mode == PredictionResponseMode::AsyncSse { + return stream_prediction_response(service, &prediction_id); + } + return ( StatusCode::ACCEPTED, Json(serde_json::json!({ "id": prediction_id, "status": "starting" })), - ); + ) + .into_response(); } // Sync mode: spawn prediction into a background task so the slot lifetime @@ -501,6 +551,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::InvalidInput(msg)) => { let metrics = build_metrics(&user_metrics); @@ -514,6 +565,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::NotReady) => { let msg = PredictionError::NotReady.to_string(); @@ -526,6 +578,7 @@ async fn create_prediction_with_id( "status": "failed" })), ) + .into_response() } Err(PredictionError::Failed(msg)) => { let metrics = build_metrics(&user_metrics); @@ -540,6 +593,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } Err(PredictionError::Cancelled) => { let metrics = build_metrics(&user_metrics); @@ -552,6 +606,7 @@ async fn create_prediction_with_id( "metrics": metrics })), ) + .into_response() } } } @@ -622,11 +677,8 @@ fn prediction_sse_stream( ) } -async fn stream_prediction( - State(service): State>, - Path(prediction_id): Path, -) -> Response { - let Some(subscription) = service.subscribe_prediction_stream(&prediction_id) else { +fn stream_prediction_response(service: Arc, prediction_id: &str) -> Response { + let Some(subscription) = service.subscribe_prediction_stream(prediction_id) else { return ( StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Prediction not found"})), @@ -668,7 +720,7 @@ async fn create_training( State(service): State>, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -677,7 +729,7 @@ async fn create_training( webhook_events_filter: default_webhook_events_filter(), }); let prediction_id = request.id.unwrap_or_else(generate_prediction_id); - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -686,7 +738,7 @@ async fn create_training( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -698,7 +750,7 @@ async fn create_training_idempotent( Path(training_id): Path, headers: HeaderMap, body: Option>, -) -> impl IntoResponse { +) -> Response { let request = body.map(|Json(r)| r).unwrap_or_else(|| PredictionRequest { id: None, input: serde_json::json!({}), @@ -719,15 +771,16 @@ async fn create_training_idempotent( "type": "value_error" }] })), - ); + ) + .into_response(); } // Idempotent: return existing state if already submitted if let Some(response) = service.get_prediction_response(&training_id) { - return (StatusCode::ACCEPTED, Json(response)); + return (StatusCode::ACCEPTED, Json(response)).into_response(); } - let respond_async = should_respond_async(&headers); + let response_mode = json_response_mode(&headers); let trace_context = extract_trace_context(&headers); create_prediction_with_id( service, @@ -736,7 +789,7 @@ async fn create_training_idempotent( request.context, request.webhook, request.webhook_events_filter, - respond_async, + response_mode, trace_context, true, ) @@ -765,7 +818,6 @@ pub fn routes(service: Arc) -> Router { .route("/shutdown", post(shutdown)) .route("/predictions", post(create_prediction)) .route("/predictions/{id}", put(create_prediction_idempotent)) - .route("/predictions/{id}/stream", get(stream_prediction)) .route("/predictions/{id}/cancel", post(cancel_prediction)) .route("/trainings", post(create_training)) .route("/trainings/{id}", put(create_training_idempotent)) @@ -1043,31 +1095,67 @@ mod tests { } #[tokio::test] - async fn stream_prediction_unknown_id_returns_404() { + async fn prediction_post_with_sse_accept_returns_sse() { let service = create_ready_service().await; let app = routes(service); let response = app .oneshot( - Request::get("/predictions/missing/stream") + Request::post("/predictions") + .header("content-type", "application/json") .header("accept", "text/event-stream") - .body(Body::empty()) + .body(Body::from(r#"{"input":{}}"#)) .unwrap(), ) .await .unwrap(); - assert_eq!(response.status(), StatusCode::NOT_FOUND); - let json = response_json(response).await; - assert_eq!(json["error"], "Prediction not found"); + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); + } + + #[tokio::test] + async fn prediction_put_with_sse_accept_returns_sse() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::put("/predictions/sse-put") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let content_type = response.headers().get("content-type").unwrap(); + assert!( + content_type + .to_str() + .unwrap() + .starts_with("text/event-stream"), + "unexpected content-type: {:?}", + content_type + ); } #[tokio::test] - async fn stream_prediction_existing_id_returns_sse() { + async fn prediction_put_existing_with_sse_accept_returns_sse() { let service = create_ready_service().await; let (_handle, _slot) = service .submit_prediction( - "stream-route".to_string(), + "existing-sse-put".to_string(), serde_json::json!({}), None, true, @@ -1078,9 +1166,10 @@ mod tests { let response = app .oneshot( - Request::get("/predictions/stream-route/stream") + Request::put("/predictions/existing-sse-put") + .header("content-type", "application/json") .header("accept", "text/event-stream") - .body(Body::empty()) + .body(Body::from(r#"{"input":{}}"#)) .unwrap(), ) .await @@ -1098,6 +1187,33 @@ mod tests { ); } + #[tokio::test] + async fn stream_prediction_route_is_removed() { + let service = create_ready_service().await; + let (_handle, _slot) = service + .submit_prediction( + "removed-stream-route".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let app = routes(service); + + let response = app + .oneshot( + Request::get("/predictions/removed-stream-route/stream") + .header("accept", "text/event-stream") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } + #[tokio::test] async fn prediction_with_custom_id() { let service = create_ready_service().await; diff --git a/examples/streaming-text/README.md b/examples/streaming-text/README.md index 947d5b6625..645664289b 100644 --- a/examples/streaming-text/README.md +++ b/examples/streaming-text/README.md @@ -2,7 +2,7 @@ Streaming text generation with `HuggingFaceTB/SmolLM2-135M-Instruct`. -This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks from the Server-Sent Events stream endpoint. +This example shows how a Cog predictor can yield text chunks as a model generates them, and how to consume those chunks with Server-Sent Events. ## Run a normal prediction @@ -22,22 +22,15 @@ Start the server: cog serve ``` -Create an async prediction with a fixed ID: +Create a prediction and request an SSE response: ```sh -curl -s -X PUT http://localhost:5000/predictions/streaming-demo \ +curl -N -X PUT http://localhost:5000/predictions/streaming-demo \ -H 'Content-Type: application/json' \ - -H 'Prefer: respond-async' \ + -H 'Accept: text/event-stream' \ -d '{"input":{"prompt":"Write a short haiku about databases","max_new_tokens":96}}' ``` -Then subscribe to its stream: - -```sh -curl -N -H 'Accept: text/event-stream' \ - http://localhost:5000/predictions/streaming-demo/stream -``` - The response includes `output` events as chunks are generated, followed by a `completed` event: ```text @@ -55,4 +48,4 @@ data: {"id":"streaming-demo","status":"succeeded",...} `predict.py` returns `Iterator[str]`. Each `yield` becomes one streamed output chunk. The example uses Hugging Face `TextIteratorStreamer` to receive generated text from `model.generate()` while generation is still running. -The normal prediction response still contains the accumulated output for compatibility. The stream endpoint is useful when clients want to display tokens as they arrive. +The normal prediction response still contains the accumulated output for compatibility. Requesting `Accept: text/event-stream` is useful when clients want to display tokens as they arrive. diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 76e7d4d47d..5f567cfd6e 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -1,14 +1,11 @@ -# Test that async generator output is available over the SSE stream endpoint. +# Test that async generator output is available when predictions are created with SSE accept. [short] skip 'requires Docker build' cog build -t $TEST_IMAGE cog serve --upload-url http://unused/ -curl -H Prefer:respond-async PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' -stdout '"status":"starting"' - -curl -N -H Accept:text/event-stream GET /predictions/sse-stream-test/stream +curl -N -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From 308ecffb4fb0dc9b69627d0e8c00907a26ccc20d Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 16:04:18 -0400 Subject: [PATCH 4/8] fix: bound prediction stream replay history --- crates/coglet/src/prediction.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 1e86a20754..d40f1b38d7 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -10,6 +10,8 @@ pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; +const MAX_STREAM_HISTORY_EVENTS: usize = 1024; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { Starting, @@ -191,6 +193,9 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { + if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { + self.stream_history.remove(0); + } self.stream_history.push(event.clone()); let _ = self.stream_tx.send(event); } @@ -712,6 +717,28 @@ mod tests { assert_eq!(replay.replay[2].json_data()["status"], "succeeded"); } + #[tokio::test] + async fn prediction_stream_replay_is_bounded_to_recent_events() { + let mut prediction = Prediction::new("pred_replay_bounded".to_string(), None); + + prediction.set_processing(); + for index in 0..1100 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + + let replay = prediction.subscribe_stream_replay(); + + assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!( + replay.replay[0].json_data(), + serde_json::json!({"chunk":76,"index":76}) + ); + assert_eq!( + replay.replay[MAX_STREAM_HISTORY_EVENTS - 1].json_data(), + serde_json::json!({"chunk":1099,"index":1099}) + ); + } + #[tokio::test] async fn wait_returns_immediately_if_terminal() { let mut pred = Prediction::new("test".to_string(), None); From 8acf5f95755ae039146ef112095ce694e3bb5847 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Thu, 14 May 2026 16:59:47 -0400 Subject: [PATCH 5/8] fix: harden SSE prediction streaming --- crates/coglet/src/orchestrator.rs | 50 ++++-- crates/coglet/src/prediction.rs | 7 + crates/coglet/src/service.rs | 163 ++++++++++++++++-- crates/coglet/src/transport/http/routes.rs | 159 +++++++++++++++-- .../tests/sse_streaming_output.txtar | 3 +- 5 files changed, 336 insertions(+), 46 deletions(-) diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 5694338e35..8fb6735eb9 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -7,7 +7,7 @@ //! 4. Run event loop routing responses to predictions //! 5. On worker crash: fail all predictions, shut down -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::process::Stdio; use std::sync::Arc; use std::sync::Mutex as StdMutex; @@ -353,15 +353,18 @@ pub struct OrchestratorReady { pub setup_logs: String, } +type RegisterPredictionMessage = ( + SlotId, + Arc>, + tokio::sync::oneshot::Sender, + tokio::sync::oneshot::Sender<()>, +); + pub struct OrchestratorHandle { child: Child, ctrl_writer: Arc>>>, - register_tx: mpsc::Sender<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + register_tx: mpsc::Sender, healthcheck_tx: mpsc::Sender>, cancel_tx: mpsc::Sender, slot_ids: Vec, @@ -375,10 +378,12 @@ impl Orchestrator for OrchestratorHandle { prediction: Arc>, idle_sender: tokio::sync::oneshot::Sender, ) { + let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); let _ = self .register_tx - .send((slot_id, prediction, idle_sender)) + .send((slot_id, prediction, idle_sender, ack_tx)) .await; + let _ = ack_rx.await; } async fn cancel_by_prediction_id(&self, prediction_id: &str) -> Result<(), OrchestratorError> { @@ -698,11 +703,7 @@ async fn run_event_loop( SlotId, FramedRead>, )>, - mut register_rx: mpsc::Receiver<( - SlotId, - Arc>, - tokio::sync::oneshot::Sender, - )>, + mut register_rx: mpsc::Receiver, mut healthcheck_rx: mpsc::Receiver>, mut cancel_rx: mpsc::Receiver, pool: Arc, @@ -718,6 +719,7 @@ async fn run_event_loop( let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; let mut pending_uploads: HashMap>> = HashMap::new(); + let mut pending_cancellations: HashSet = HashSet::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -923,17 +925,19 @@ async fn run_event_loop( } } None => { - tracing::debug!(%prediction_id, "Cancel requested for unknown prediction (may have already completed)"); + tracing::debug!(%prediction_id, "Cancel requested for unknown prediction; storing pending cancellation"); + pending_cancellations.insert(prediction_id); } } } - Some((slot_id, prediction, idle_sender)) = register_rx.recv() => { + Some((slot_id, prediction, idle_sender, registered_tx)) = register_rx.recv() => { let prediction_id = match try_lock_prediction(&prediction) { Some(p) => p.id().to_string(), None => { // Mutex poisoned during registration - prediction already failed tracing::error!(%slot_id, "Prediction mutex poisoned during registration"); + let _ = registered_tx.send(()); continue; } }; @@ -949,6 +953,24 @@ async fn run_event_loop( ); tracing::debug!(%slot_id, %prediction_id, "Registered prediction"); predictions.insert(slot_id, prediction); + let pending_cancel = pending_cancellations.remove(&prediction_id); + let _ = registered_tx.send(()); + if pending_cancel { + tracing::info!( + target: "coglet::prediction", + %prediction_id, + %slot_id, + "Applying pending cancellation" + ); + let mut writer = ctrl_writer.lock().await; + if let Err(e) = writer.send(ControlRequest::Cancel { slot: slot_id }).await { + tracing::error!( + %slot_id, + error = %e, + "Failed to send pending cancel request to worker" + ); + } + } } Some((slot_id, result)) = slot_msg_rx.recv() => { diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index d40f1b38d7..10c87fe83e 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -92,6 +92,7 @@ pub enum PredictionStreamEvent { pub struct PredictionStreamReplay { pub replay: Vec, + pub skipped: u64, pub receiver: tokio::sync::broadcast::Receiver, } @@ -144,6 +145,7 @@ pub struct Prediction { completion: Arc, stream_tx: tokio::sync::broadcast::Sender, stream_history: Vec, + stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, } @@ -165,6 +167,7 @@ impl Prediction { completion: Arc::new(Notify::new()), stream_tx, stream_history: Vec::new(), + stream_history_skipped: 0, metrics: HashMap::new(), } } @@ -184,6 +187,7 @@ impl Prediction { pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { PredictionStreamReplay { replay: self.stream_history.clone(), + skipped: self.stream_history_skipped, receiver: self.stream_tx.subscribe(), } } @@ -195,6 +199,7 @@ impl Prediction { fn emit_stream_event(&mut self, event: PredictionStreamEvent) { if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { self.stream_history.remove(0); + self.stream_history_skipped += 1; } self.stream_history.push(event.clone()); let _ = self.stream_tx.send(event); @@ -710,6 +715,7 @@ mod tests { .collect(); assert_eq!(events, vec!["start", "output", "completed"]); + assert_eq!(replay.skipped, 0); assert_eq!( replay.replay[1].json_data(), serde_json::json!({"chunk":"hello","index":0}) @@ -729,6 +735,7 @@ mod tests { let replay = prediction.subscribe_stream_replay(); assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 80a42676bf..30a1c85163 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -79,7 +79,7 @@ struct PredictionEntry { prediction: Arc>, cancel_token: CancellationToken, input: serde_json::Value, - respond_async: bool, + cancel_on_stream_drop: bool, } /// Handle to a submitted prediction for cancellation on disconnect. @@ -110,6 +110,7 @@ impl PredictionHandle { pub struct PredictionStreamSubscription { id: String, replay: Vec, + skipped: u64, receiver: tokio::sync::broadcast::Receiver, guard: PredictionStreamGuard, } @@ -123,22 +124,23 @@ impl PredictionStreamSubscription { self, ) -> ( Vec, + u64, tokio::sync::broadcast::Receiver, PredictionStreamGuard, ) { - (self.replay, self.receiver, self.guard) + (self.replay, self.skipped, self.receiver, self.guard) } } pub struct PredictionStreamGuard { id: String, service: Arc, - respond_async: bool, + cancel_on_stream_drop: bool, } impl Drop for PredictionStreamGuard { fn drop(&mut self) { - if self.respond_async { + if !self.cancel_on_stream_drop { return; } @@ -459,7 +461,7 @@ impl PredictionService { id: String, input: serde_json::Value, webhook: Option, - respond_async: bool, + cancel_on_stream_drop: bool, ) -> Result<(PredictionHandle, UnregisteredPredictionSlot), CreatePredictionError> { let health = *self.health.read().await; if health != Health::Ready { @@ -487,7 +489,7 @@ impl PredictionService { prediction: prediction_arc, cancel_token: cancel_token.clone(), input, - respond_async, + cancel_on_stream_drop, }, ); @@ -521,15 +523,16 @@ impl PredictionService { ) -> Option { let entry = self.predictions.get(id)?; let stream = entry.prediction.lock().ok()?.subscribe_stream_replay(); - let respond_async = entry.respond_async; + let cancel_on_stream_drop = entry.cancel_on_stream_drop; Some(PredictionStreamSubscription { id: id.to_string(), replay: stream.replay, + skipped: stream.skipped, receiver: stream.receiver, guard: PredictionStreamGuard { id: id.to_string(), service: Arc::clone(self), - respond_async, + cancel_on_stream_drop, }, }) } @@ -626,6 +629,22 @@ impl PredictionService { ))); } + let was_cancelled_before_send = try_lock_prediction(&prediction_arc) + .map(|p| p.is_canceled()) + .unwrap_or(false); + if was_cancelled_before_send + && let Err(e) = state + .orchestrator + .cancel_by_prediction_id(&prediction_id) + .await + { + tracing::error!( + prediction_id = %prediction_id, + error = %e, + "Failed to forward pending cancellation after registration" + ); + } + // Wait for prediction to complete // Check if already terminal first to avoid race with fast completions let (already_terminal, completion) = { @@ -890,6 +909,58 @@ mod tests { } } + struct CancelRecordingOrchestrator { + cancel_count: AtomicUsize, + prediction: std::sync::Mutex>>>, + } + + impl CancelRecordingOrchestrator { + fn new() -> Self { + Self { + cancel_count: AtomicUsize::new(0), + prediction: std::sync::Mutex::new(None), + } + } + + fn cancel_count(&self) -> usize { + self.cancel_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl Orchestrator for CancelRecordingOrchestrator { + async fn register_prediction( + &self, + slot_id: SlotId, + prediction: Arc>, + idle_sender: tokio::sync::oneshot::Sender, + ) { + *self.prediction.lock().unwrap() = Some(prediction); + let _ = idle_sender.send(InactiveSlotIdleToken::new(slot_id).activate()); + } + + async fn cancel_by_prediction_id( + &self, + _prediction_id: &str, + ) -> Result<(), crate::orchestrator::OrchestratorError> { + self.cancel_count.fetch_add(1, Ordering::SeqCst); + if let Some(prediction) = self.prediction.lock().unwrap().as_ref() { + prediction.lock().unwrap().set_canceled(); + } + Ok(()) + } + + async fn healthcheck( + &self, + ) -> Result { + Ok(HealthcheckResult::healthy()) + } + + async fn shutdown(&self) -> Result<(), crate::orchestrator::OrchestratorError> { + Ok(()) + } + } + async fn create_test_pool(num_slots: usize) -> Arc { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::SlotRequest; @@ -1074,9 +1145,31 @@ mod tests { svc.set_orchestrator(pool, orchestrator).await; svc.set_health(Health::Ready).await; + let (_handle, _slot) = svc + .submit_prediction("sync-stream".to_string(), serde_json::json!({}), None, true) + .await + .unwrap(); + + let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + drop(subscription); + tokio::time::sleep(Duration::from_millis(25)).await; + + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + + #[tokio::test] + async fn dropping_async_json_stream_subscription_does_not_cancel_prediction() { + let svc = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CountingCancelOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + let (_handle, _slot) = svc .submit_prediction( - "sync-stream".to_string(), + "async-json-stream".to_string(), serde_json::json!({}), None, false, @@ -1084,15 +1177,17 @@ mod tests { .await .unwrap(); - let subscription = svc.subscribe_prediction_stream("sync-stream").unwrap(); + let subscription = svc + .subscribe_prediction_stream("async-json-stream") + .unwrap(); drop(subscription); tokio::time::sleep(Duration::from_millis(25)).await; - assert_eq!(orchestrator_ref.cancel_count(), 1); + assert_eq!(orchestrator_ref.cancel_count(), 0); } #[tokio::test] - async fn dropping_async_stream_subscription_does_not_cancel_prediction() { + async fn dropping_live_sse_stream_subscription_cancels_prediction() { let svc = Arc::new(PredictionService::new_no_pool()); let pool = create_test_pool(1).await; let orchestrator = Arc::new(CountingCancelOrchestrator::new()); @@ -1103,7 +1198,7 @@ mod tests { let (_handle, _slot) = svc .submit_prediction( - "async-stream".to_string(), + "live-sse-stream".to_string(), serde_json::json!({}), None, true, @@ -1111,11 +1206,11 @@ mod tests { .await .unwrap(); - let subscription = svc.subscribe_prediction_stream("async-stream").unwrap(); + let subscription = svc.subscribe_prediction_stream("live-sse-stream").unwrap(); drop(subscription); tokio::time::sleep(Duration::from_millis(25)).await; - assert_eq!(orchestrator_ref.cancel_count(), 0); + assert_eq!(orchestrator_ref.cancel_count(), 1); } #[tokio::test] @@ -1133,7 +1228,7 @@ mod tests { "completed-sync-stream".to_string(), serde_json::json!({}), None, - false, + true, ) .await .unwrap(); @@ -1208,6 +1303,42 @@ mod tests { assert_eq!(orch_ref.register_count(), 1); } + #[tokio::test] + async fn predict_forwards_cancel_token_set_before_registration() { + let svc = PredictionService::new_no_pool(); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(CancelRecordingOrchestrator::new()); + let orchestrator_ref = Arc::clone(&orchestrator); + + svc.set_orchestrator(pool, orchestrator).await; + svc.set_health(Health::Ready).await; + + let (handle, slot) = svc + .submit_prediction( + "pre-register-cancel".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + handle.cancel_token().cancel(); + + let result = tokio::time::timeout( + Duration::from_millis(100), + svc.predict( + slot, + serde_json::json!({}), + std::collections::HashMap::new(), + ), + ) + .await + .expect("prediction should observe cancellation after registration"); + + assert!(matches!(result, Err(PredictionError::Cancelled))); + assert_eq!(orchestrator_ref.cancel_count(), 1); + } + #[tokio::test] async fn health_shows_busy_when_all_slots_used() { let svc = PredictionService::new_no_pool(); diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 1cdc6f6c18..869f0307a6 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -430,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - response_mode != PredictionResponseMode::SyncJson, + response_mode != PredictionResponseMode::AsyncJson, ) .await { @@ -462,6 +462,12 @@ async fn create_prediction_with_id( // Async mode: spawn background task, return immediately if response_mode != PredictionResponseMode::SyncJson { + let sse_subscription = if response_mode == PredictionResponseMode::AsyncSse { + service.subscribe_prediction_stream(&prediction_id) + } else { + None + }; + let service_clone = Arc::clone(&service); let id_for_cleanup = prediction_id.clone(); let context_async = context.clone(); @@ -475,7 +481,14 @@ async fn create_prediction_with_id( }); if response_mode == PredictionResponseMode::AsyncSse { - return stream_prediction_response(service, &prediction_id); + let Some(subscription) = sse_subscription else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Prediction not found"})), + ) + .into_response(); + }; + return stream_prediction_subscription_response(subscription); } return ( @@ -634,10 +647,11 @@ fn stream_event_to_sse(event: crate::prediction::PredictionStreamEvent) -> Event fn prediction_sse_stream( subscription: PredictionStreamSubscription, ) -> impl futures::Stream> { - let (replay, receiver, guard) = subscription.into_parts(); + let (replay, replay_skipped, receiver, guard) = subscription.into_parts(); struct StreamState { replay: std::collections::VecDeque, + replay_skipped: u64, receiver: tokio::sync::broadcast::Receiver, _guard: crate::service::PredictionStreamGuard, done: bool, @@ -646,6 +660,7 @@ fn prediction_sse_stream( futures::stream::unfold( StreamState { replay: replay.into(), + replay_skipped, receiver, _guard: guard, done: false, @@ -655,23 +670,44 @@ fn prediction_sse_stream( return None; } + if state.replay_skipped > 0 { + let skipped = state.replay_skipped; + state.replay_skipped = 0; + state.done = true; + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream replay truncated; events were dropped", + "skipped": skipped, + })) + .expect("SSE replay truncation error serializes to JSON"); + return Some((Ok(event), state)); + } + if let Some(event) = state.replay.pop_front() { state.done = event.event_name() == "completed"; return Some((Ok(stream_event_to_sse(event)), state)); } - loop { - match state.receiver.recv().await { - Ok(event) => { - state.done = event.event_name() == "completed"; - return Some((Ok(stream_event_to_sse(event)), state)); - } - Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { - tracing::warn!(skipped, "SSE prediction stream receiver lagged"); - continue; - } - Err(tokio::sync::broadcast::error::RecvError::Closed) => return None, + match state.receiver.recv().await { + Ok(event) => { + state.done = event.event_name() == "completed"; + Some((Ok(stream_event_to_sse(event)), state)) + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => { + tracing::warn!(skipped, "SSE prediction stream receiver lagged"); + state.done = true; + // In the future, this could become backpressure or cursor-based replay. + let event = Event::default() + .event("error") + .json_data(serde_json::json!({ + "error": "SSE stream lagged; events were dropped", + "skipped": skipped, + })) + .expect("SSE lag error serializes to JSON"); + Some((Ok(event), state)) } + Err(tokio::sync::broadcast::error::RecvError::Closed) => None, } }, ) @@ -686,6 +722,10 @@ fn stream_prediction_response(service: Arc, prediction_id: &s .into_response(); }; + stream_prediction_subscription_response(subscription) +} + +fn stream_prediction_subscription_response(subscription: PredictionStreamSubscription) -> Response { Sse::new(prediction_sse_stream(subscription)) .keep_alive( KeepAlive::new() @@ -1120,6 +1160,97 @@ mod tests { "unexpected content-type: {:?}", content_type ); + + let body = response.into_body(); + let bytes = body.collect().await.unwrap().to_bytes(); + let sse = String::from_utf8(bytes.to_vec()).unwrap(); + assert!(sse.contains("event: completed"), "SSE body: {sse}"); + assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); + } + + #[tokio::test] + async fn lagged_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "lagged-stream".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + let subscription = service + .subscribe_prediction_stream("lagged-stream") + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("lagged SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!(sse.contains("SSE stream lagged"), "SSE body: {sse}"); + assert!(sse.contains("skipped"), "SSE body: {sse}"); + } + + #[tokio::test] + async fn truncated_replay_prediction_sse_stream_emits_error_and_closes() { + let service = Arc::new(PredictionService::new_no_pool()); + let pool = create_test_pool(1).await; + let orchestrator = Arc::new(MockOrchestrator::never_complete()); + service.set_orchestrator(pool, orchestrator).await; + service.set_health(Health::Ready).await; + + let (_handle, slot) = service + .submit_prediction( + "truncated-replay".to_string(), + serde_json::json!({}), + None, + true, + ) + .await + .unwrap(); + + { + let prediction = slot.prediction(); + let mut prediction = prediction.lock().unwrap(); + for index in 0..1030 { + prediction.append_output_chunk(serde_json::json!(index), index); + } + } + + let subscription = service + .subscribe_prediction_stream("truncated-replay") + .unwrap(); + let response = Sse::new(prediction_sse_stream(subscription)).into_response(); + let collected = + tokio::time::timeout(Duration::from_millis(100), response.into_body().collect()) + .await + .expect("truncated replay SSE stream should close after emitting an error") + .unwrap(); + let sse = String::from_utf8(collected.to_bytes().to_vec()).unwrap(); + assert!(sse.contains("event: error"), "SSE body: {sse}"); + assert!( + sse.contains("SSE stream replay truncated"), + "SSE body: {sse}" + ); + assert!(sse.contains("skipped"), "SSE body: {sse}"); } #[tokio::test] diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 5f567cfd6e..32c757008f 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -2,10 +2,9 @@ [short] skip 'requires Docker build' -cog build -t $TEST_IMAGE cog serve --upload-url http://unused/ -curl -N -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From ff3140eb437f87d0a152d8752f9a4fc777c95ceb Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 12:05:00 -0400 Subject: [PATCH 6/8] fix: address SSE review feedback --- crates/coglet/src/prediction.rs | 22 +++++++++---------- crates/coglet/src/service.rs | 3 +++ crates/coglet/src/transport/http/routes.rs | 2 +- .../tests/sse_streaming_output.txtar | 1 + 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 10c87fe83e..5c3419dc9a 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -1,6 +1,6 @@ //! Prediction state tracking. -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; use std::time::Instant; @@ -10,7 +10,7 @@ pub use tokio_util::sync::CancellationToken; use crate::bridge::protocol::{LogSource, MetricMode}; use crate::webhook::{WebhookEventType, WebhookSender}; -const MAX_STREAM_HISTORY_EVENTS: usize = 1024; +const STREAM_EVENT_BUFFER_CAPACITY: usize = 1024; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PredictionStatus { @@ -144,7 +144,7 @@ pub struct Prediction { webhook: Option, completion: Arc, stream_tx: tokio::sync::broadcast::Sender, - stream_history: Vec, + stream_history: VecDeque, stream_history_skipped: u64, /// User-emitted metrics. Merged with system metrics (predict_time) in terminal response. metrics: HashMap, @@ -152,7 +152,7 @@ pub struct Prediction { impl Prediction { pub fn new(id: String, webhook: Option) -> Self { - let (stream_tx, _) = tokio::sync::broadcast::channel(1024); + let (stream_tx, _) = tokio::sync::broadcast::channel(STREAM_EVENT_BUFFER_CAPACITY); Self { id, @@ -166,7 +166,7 @@ impl Prediction { webhook, completion: Arc::new(Notify::new()), stream_tx, - stream_history: Vec::new(), + stream_history: VecDeque::new(), stream_history_skipped: 0, metrics: HashMap::new(), } @@ -186,7 +186,7 @@ impl Prediction { pub fn subscribe_stream_replay(&self) -> PredictionStreamReplay { PredictionStreamReplay { - replay: self.stream_history.clone(), + replay: self.stream_history.iter().cloned().collect(), skipped: self.stream_history_skipped, receiver: self.stream_tx.subscribe(), } @@ -197,11 +197,11 @@ impl Prediction { } fn emit_stream_event(&mut self, event: PredictionStreamEvent) { - if self.stream_history.len() == MAX_STREAM_HISTORY_EVENTS { - self.stream_history.remove(0); + if self.stream_history.len() == STREAM_EVENT_BUFFER_CAPACITY { + self.stream_history.pop_front(); self.stream_history_skipped += 1; } - self.stream_history.push(event.clone()); + self.stream_history.push_back(event.clone()); let _ = self.stream_tx.send(event); } @@ -734,14 +734,14 @@ mod tests { let replay = prediction.subscribe_stream_replay(); - assert_eq!(replay.replay.len(), MAX_STREAM_HISTORY_EVENTS); + assert_eq!(replay.replay.len(), STREAM_EVENT_BUFFER_CAPACITY); assert_eq!(replay.skipped, 77); assert_eq!( replay.replay[0].json_data(), serde_json::json!({"chunk":76,"index":76}) ); assert_eq!( - replay.replay[MAX_STREAM_HISTORY_EVENTS - 1].json_data(), + replay.replay[STREAM_EVENT_BUFFER_CAPACITY - 1].json_data(), serde_json::json!({"chunk":1099,"index":1099}) ); } diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 30a1c85163..d7a85726cb 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -144,6 +144,9 @@ impl Drop for PredictionStreamGuard { return; } + // Prediction cleanup may remove the service entry before the SSE response + // finishes draining. Missing entries deliberately report zero receivers and + // terminal state so this guard cannot cancel an already-cleaned prediction. if self.service.stream_receiver_count(&self.id) == 0 && !self.service.prediction_is_terminal(&self.id) { diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index 869f0307a6..d5bf7e0064 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -430,7 +430,7 @@ async fn create_prediction_with_id( prediction_id.clone(), input.clone(), webhook_sender, - response_mode != PredictionResponseMode::AsyncJson, + response_mode == PredictionResponseMode::AsyncSse, ) .await { diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 32c757008f..4d3f2102c9 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -5,6 +5,7 @@ cog serve --upload-url http://unused/ curl -H Accept:text/event-stream PUT /predictions/sse-stream-test '{"id":"sse-stream-test","input":{}}' +stdout 'event: start' stdout 'event: output' stdout 'data: {"chunk":"chunk-1","index":0}' stdout 'event: output' From 8c9c9826e8a21628fe4200da23e96e63dfb0ac46 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 16:05:13 -0400 Subject: [PATCH 7/8] feat: make prediction streaming opt-in --- architecture/02-schema.md | 73 ++++++++++--------- crates/coglet/src/service.rs | 19 +++++ crates/coglet/src/transport/http/routes.rs | 56 ++++++++++++++ docs/llms.txt | 18 +++-- docs/python.md | 18 +++-- .../tests/sse_requires_streaming_opt_in.txtar | 27 +++++++ .../tests/sse_streaming_output.txtar | 3 +- pkg/schema/openapi.go | 48 ++++++------ pkg/schema/openapi_test.go | 28 +++++++ pkg/schema/python/parser.go | 51 +++++++++++-- pkg/schema/python/parser_test.go | 71 ++++++++++++++++++ pkg/schema/types.go | 7 +- python/cog/__init__.py | 11 +++ python/tests/test_types.py | 14 ++++ 14 files changed, 366 insertions(+), 78 deletions(-) create mode 100644 integration-tests/tests/sse_requires_streaming_opt_in.txtar diff --git a/architecture/02-schema.md b/architecture/02-schema.md index bd217199f3..26bc99cbdf 100644 --- a/architecture/02-schema.md +++ b/architecture/02-schema.md @@ -186,40 +186,43 @@ Each `SchemaType` produces its JSON Schema fragment via `JSONSchema()`: ### Output Types -| Python | SchemaType | JSON Schema | -| -------------------------- | ------------------------ | --------------------------------------------------------------- | -| `str` | `SchemaPrimitive` | `{"type": "string"}` | -| `int` | `SchemaPrimitive` | `{"type": "integer"}` | -| `float` | `SchemaPrimitive` | `{"type": "number"}` | -| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | -| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | -| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | -| `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | -| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | -| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | -| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | +| Python | SchemaType | JSON Schema | +| -------------------------------- | --------------------------------------- | --------------------------------------------------------------- | +| `str` | `SchemaPrimitive` | `{"type": "string"}` | +| `int` | `SchemaPrimitive` | `{"type": "integer"}` | +| `float` | `SchemaPrimitive` | `{"type": "number"}` | +| `bool` | `SchemaPrimitive` | `{"type": "boolean"}` | +| `Path` | `SchemaPrimitive` | `{"type": "string", "format": "uri"}` | +| `dict` (bare) | `SchemaAny` | `{"type": "object"}` | +| `dict[str, V]` | `SchemaDict` | `{"type": "object", "additionalProperties": V}` | +| `list` (bare) | `SchemaArray(SchemaAny)` | `{"type": "array", "items": {"type": "object"}}` | +| `list[T]` | `SchemaArray` | `{"type": "array", "items": T}` | +| `Annotated[T, cog.Opaque]` | `SchemaPrimitive(TypeAny)` | `{"type": "object"}` | | `Annotated[list[T], cog.Opaque]` | `SchemaArray(SchemaPrimitive(TypeAny))` | `{"type": "array", "items": {"type": "object"}}` | -| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | -| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | -| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | -| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | +| `BaseModel` subclass | `SchemaObject` | `{"type": "object", "properties": {...}}` | +| `Iterator[T]` | `SchemaIterator` | `{"type": "array", "items": T, "x-cog-array-type": "iterator"}` | +| `ConcatenateIterator[str]` | `SchemaConcatIterator` | Streaming token output | +| Nested types | Recursive | `dict[str, list[dict[str, int]]]` fully supported | ### Unsupported Output Types -| Python | Error | -| --------------------------- | -------------------------------------------------------------------- | -| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | -| `Union[A, B]` | Ambiguous for downstream consumers | +| Python | Error | +| --------------------------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `Optional[T]` / `T \| None` | Predictions must succeed with a value or fail with an error | +| `Union[A, B]` | Ambiguous for downstream consumers | | External package types | Cannot be statically analyzed — define as BaseModel, use .pyi stub, or mark JSON-shaped values with `Annotated[..., cog.Opaque]` | ## Cog-Specific Extensions -| Extension | Purpose | -| --------------------- | ------------------------------------------------- | -| `x-order` | Preserves parameter order from function signature | -| `x-cog-array-type` | Marks iterators vs regular arrays | -| `x-cog-array-display` | Hints for how to display streaming output | -| `x-cog-secret` | Marks sensitive inputs | +| Extension | Purpose | +| --------------------- | --------------------------------------------------- | +| `x-order` | Preserves parameter order from function signature | +| `x-cog-array-type` | Marks iterators vs regular arrays | +| `x-cog-array-display` | Hints for how to display streaming output | +| `x-cog-secret` | Marks sensitive inputs | +| `x-cog-streaming` | Marks prediction operations that accept SSE clients | + +Iterator output types describe the shape of accumulated JSON output. SSE response support is a separate prediction operation capability and is only advertised when the prediction handler opts in with `@cog.streaming`. ## Where the Schema Lives @@ -311,12 +314,12 @@ A simplified example showing a multi-file predictor with structured output: ## Code References -| File | Purpose | -| ----------------------------- | -------------------------------------------------------------------- | -| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | -| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | -| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | -| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | -| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | -| `pkg/schema/errors.go` | Typed schema error kinds | -| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | +| File | Purpose | +| --------------------------- | -------------------------------------------------------------------- | +| `pkg/schema/schema_type.go` | `SchemaType` ADT, `ResolveSchemaType()`, `JSONSchema()` generation | +| `pkg/schema/types.go` | `PredictorInfo`, `PrimitiveType`, `FieldType`, `InputField`, imports | +| `pkg/schema/python/` | Tree-sitter Python parser and cross-file resolution | +| `pkg/schema/openapi.go` | OpenAPI document assembly from `PredictorInfo` | +| `pkg/schema/generator.go` | Top-level `Generate()`, `GenerateCombined()`, `Parser` type | +| `pkg/schema/errors.go` | Typed schema error kinds | +| `pkg/image/build.go` | Build-time schema generation entry point and schema file validation | diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index d7a85726cb..bd33d9e5b6 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -226,6 +226,7 @@ pub struct PredictionService { schema: RwLock>, input_validator: RwLock>, train_validator: RwLock>, + supports_prediction_streaming: RwLock, } impl PredictionService { @@ -245,6 +246,7 @@ impl PredictionService { schema: RwLock::new(None), input_validator: RwLock::new(None), train_validator: RwLock::new(None), + supports_prediction_streaming: RwLock::new(false), } } @@ -299,6 +301,10 @@ impl PredictionService { self.train_validator.read().await.is_some() } + pub async fn supports_prediction_streaming(&self) -> bool { + *self.supports_prediction_streaming.read().await + } + /// Get the permit pool from orchestrator. pub async fn pool(&self) -> Option> { if let Some(ref state) = *self.orchestrator.read().await { @@ -359,6 +365,9 @@ impl PredictionService { } pub async fn set_schema(&self, schema: serde_json::Value) { + let supports_prediction_streaming = Self::schema_supports_prediction_streaming(&schema); + *self.supports_prediction_streaming.write().await = supports_prediction_streaming; + // Compile input validators from the schema components let validator = InputValidator::from_openapi_schema(&schema); if let Some(v) = &validator { @@ -382,6 +391,16 @@ impl PredictionService { *self.schema.write().await = Some(schema); } + fn schema_supports_prediction_streaming(schema: &serde_json::Value) -> bool { + schema + .get("paths") + .and_then(|paths| paths.get("/predictions")) + .and_then(|path| path.get("post")) + .and_then(|operation| operation.get("x-cog-streaming")) + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + } + pub async fn schema(&self) -> Option { self.schema.read().await.clone() } diff --git a/crates/coglet/src/transport/http/routes.rs b/crates/coglet/src/transport/http/routes.rs index d5bf7e0064..6822fdead3 100644 --- a/crates/coglet/src/transport/http/routes.rs +++ b/crates/coglet/src/transport/http/routes.rs @@ -245,6 +245,16 @@ fn prediction_response_mode(headers: &HeaderMap) -> PredictionResponseMode { } } +fn streaming_not_supported_response() -> Response { + ( + StatusCode::NOT_ACCEPTABLE, + Json(serde_json::json!({ + "error": "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + })), + ) + .into_response() +} + fn json_response_mode(headers: &HeaderMap) -> PredictionResponseMode { if should_respond_async(headers) { PredictionResponseMode::AsyncJson @@ -330,6 +340,9 @@ async fn create_prediction_idempotent( // Check if prediction with this ID is already in-flight if let Some(response) = service.get_prediction_response(&prediction_id) { if response_mode == PredictionResponseMode::AsyncSse { + if !service.supports_prediction_streaming().await { + return streaming_not_supported_response(); + } return stream_prediction_response(service, &prediction_id); } return (StatusCode::ACCEPTED, Json(response)).into_response(); @@ -386,6 +399,13 @@ async fn create_prediction_with_id( trace_context: TraceContext, is_training: bool, ) -> Response { + if !is_training + && response_mode == PredictionResponseMode::AsyncSse + && !service.supports_prediction_streaming().await + { + return streaming_not_supported_response(); + } + // Strip unknown fields and validate in one pass. Unknown inputs are // silently dropped to match Replicate's historical API behavior. let (stripped, validation_result) = if is_training { @@ -1076,6 +1096,15 @@ mod tests { service } + async fn enable_prediction_streaming(service: &PredictionService) { + service + .set_schema(serde_json::json!({ + "paths": {"/predictions": {"post": {"x-cog-streaming": true}}}, + "components": {"schemas": {"Input": {"type": "object", "properties": {}}}} + })) + .await; + } + #[tokio::test] async fn health_check_ready_with_orchestrator() { let service = create_ready_service().await; @@ -1137,6 +1166,7 @@ mod tests { #[tokio::test] async fn prediction_post_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let app = routes(service); let response = app @@ -1168,6 +1198,30 @@ mod tests { assert!(sse.contains(r#""status":"succeeded""#), "SSE body: {sse}"); } + #[tokio::test] + async fn prediction_post_with_sse_accept_rejects_when_not_opted_in() { + let service = create_ready_service().await; + let app = routes(service); + + let response = app + .oneshot( + Request::post("/predictions") + .header("content-type", "application/json") + .header("accept", "text/event-stream") + .body(Body::from(r#"{"input":{}}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::NOT_ACCEPTABLE); + let json = response_json(response).await; + assert_eq!( + json["error"], + "This model does not support streaming responses. Add @cog.streaming to predict() to enable SSE." + ); + } + #[tokio::test] async fn lagged_prediction_sse_stream_emits_error_and_closes() { let service = Arc::new(PredictionService::new_no_pool()); @@ -1256,6 +1310,7 @@ mod tests { #[tokio::test] async fn prediction_put_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let app = routes(service); let response = app @@ -1284,6 +1339,7 @@ mod tests { #[tokio::test] async fn prediction_put_existing_with_sse_accept_returns_sse() { let service = create_ready_service().await; + enable_prediction_streaming(&service).await; let (_handle, _slot) = service .submit_prediction( "existing-sse-put".to_string(), diff --git a/docs/llms.txt b/docs/llms.txt index b38aa99586..ead330f0c2 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -1975,13 +1975,17 @@ class Predictor(BasePredictor): Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To define streaming-shaped output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. + +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the `predict()` method with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py -from cog import BasePredictor, Path from typing import Iterator +from cog import BasePredictor, Path, streaming + class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[Path]: done = False while not done: @@ -1993,9 +1997,11 @@ If you have an [async `predict()` method](#async-predictors-and-concurrency), us ```py from typing import AsyncIterator -from cog import BasePredictor, Path + +from cog import BasePredictor, Path, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncIterator[Path]: done = False while not done: @@ -2006,9 +2012,10 @@ class Predictor(BasePredictor): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BasePredictor, ConcatenateIterator, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -2018,9 +2025,10 @@ class Predictor(BasePredictor): Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BasePredictor, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/docs/python.md b/docs/python.md index 06d2a8adc3..76c294728d 100644 --- a/docs/python.md +++ b/docs/python.md @@ -259,13 +259,17 @@ class Predictor(BasePredictor): Cog models can stream output as the `predict()` method is running. For example, a language model can output tokens as they're being generated and an image generation model can output images as they are being generated. -To support streaming output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. +To define streaming-shaped output in your Cog model, add `from typing import Iterator` to your predict.py file. The `typing` package is a part of Python's standard library so it doesn't need to be installed. Then add a return type annotation to the `predict()` method in the form `-> Iterator[]` where `` can be one of `str`, `int`, `float`, `bool`, or `cog.Path`. + +To allow clients to receive chunks as server-sent events with `Accept: text/event-stream`, decorate the `predict()` method with `@cog.streaming` or `@streaming` imported from `cog`. Without the decorator, iterator outputs still work in normal JSON responses, but SSE requests return `406 Not Acceptable`. ```py -from cog import BasePredictor, Path from typing import Iterator +from cog import BasePredictor, Path, streaming + class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[Path]: done = False while not done: @@ -277,9 +281,11 @@ If you have an [async `predict()` method](#async-predictors-and-concurrency), us ```py from typing import AsyncIterator -from cog import BasePredictor, Path + +from cog import BasePredictor, Path, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncIterator[Path]: done = False while not done: @@ -290,9 +296,10 @@ class Predictor(BasePredictor): If you're streaming text output, you can use `ConcatenateIterator` to hint that the output should be concatenated together into a single string. This is useful on Replicate to display the output as a string instead of a list of strings. ```py -from cog import BasePredictor, Path, ConcatenateIterator +from cog import BasePredictor, ConcatenateIterator, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> ConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: @@ -302,9 +309,10 @@ class Predictor(BasePredictor): Or for async `predict()` methods, use `AsyncConcatenateIterator`: ```py -from cog import BasePredictor, Path, AsyncConcatenateIterator +from cog import AsyncConcatenateIterator, BasePredictor, streaming class Predictor(BasePredictor): + @streaming async def predict(self) -> AsyncConcatenateIterator[str]: tokens = ["The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] for token in tokens: diff --git a/integration-tests/tests/sse_requires_streaming_opt_in.txtar b/integration-tests/tests/sse_requires_streaming_opt_in.txtar new file mode 100644 index 0000000000..7edf04fab1 --- /dev/null +++ b/integration-tests/tests/sse_requires_streaming_opt_in.txtar @@ -0,0 +1,27 @@ +# Test that SSE requires @cog.streaming while undecorated iterators still work as JSON. + +[short] skip 'requires Docker build' + +cog serve --upload-url http://unused/ + +cog predict -i count=2 +stdout '"output":\["chunk-0","chunk-1"\]' + +! curl -H Accept:text/event-stream PUT /predictions/no-streaming '{"id":"no-streaming","input":{"count":2}}' +stderr 'This model does not support streaming responses' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from typing import Iterator + +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self, count: int) -> Iterator[str]: + for index in range(count): + yield f"chunk-{index}" diff --git a/integration-tests/tests/sse_streaming_output.txtar b/integration-tests/tests/sse_streaming_output.txtar index 4d3f2102c9..23646106f5 100644 --- a/integration-tests/tests/sse_streaming_output.txtar +++ b/integration-tests/tests/sse_streaming_output.txtar @@ -22,10 +22,11 @@ predict: "predict.py:Predictor" import time from typing import Iterator -from cog import BasePredictor +from cog import BasePredictor, streaming class Predictor(BasePredictor): + @streaming def predict(self) -> Iterator[str]: time.sleep(0.25) yield "chunk-1" diff --git a/pkg/schema/openapi.go b/pkg/schema/openapi.go index 9435ea8248..f15f4471ca 100644 --- a/pkg/schema/openapi.go +++ b/pkg/schema/openapi.go @@ -187,38 +187,40 @@ func buildOpenAPISpec(info *PredictorInfo) map[string]any { }) // Main endpoint (predict or train) - paths.Set(endpoint, map[string]any{ - "post": map[string]any{ - "summary": summary, - "description": description, - "operationId": opID, - "requestBody": map[string]any{ + mainOperation := map[string]any{ + "summary": summary, + "description": description, + "operationId": opID, + "requestBody": map[string]any{ + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": requestRef}, + }, + }, + }, + "responses": map[string]any{ + "200": map[string]any{ + "description": "Successful Response", "content": map[string]any{ "application/json": map[string]any{ - "schema": map[string]any{"$ref": requestRef}, + "schema": map[string]any{"$ref": responseRef}, }, }, }, - "responses": map[string]any{ - "200": map[string]any{ - "description": "Successful Response", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": responseRef}, - }, - }, - }, - "422": map[string]any{ - "description": "Validation Error", - "content": map[string]any{ - "application/json": map[string]any{ - "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, - }, + "422": map[string]any{ + "description": "Validation Error", + "content": map[string]any{ + "application/json": map[string]any{ + "schema": map[string]any{"$ref": "#/components/schemas/HTTPValidationError"}, }, }, }, }, - }) + } + if !isTrain && info.SupportsStreaming { + mainOperation["x-cog-streaming"] = true + } + paths.Set(endpoint, map[string]any{"post": mainOperation}) // Cancel endpoint paths.Set(cancelEP, map[string]any{ diff --git a/pkg/schema/openapi_test.go b/pkg/schema/openapi_test.go index a3a9f7a8a6..9cc037c684 100644 --- a/pkg/schema/openapi_test.go +++ b/pkg/schema/openapi_test.go @@ -657,6 +657,34 @@ func TestOutputConcatenateIterator(t *testing.T) { assert.Equal(t, "concatenate", output["x-cog-array-display"]) } +func TestPredictionOperationIncludesStreamingExtensionWhenEnabled(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + SupportsStreaming: true, + } + + spec := parseSpec(t, info) + post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + assert.Equal(t, true, post["x-cog-streaming"]) +} + +func TestPredictionOperationOmitsStreamingExtensionByDefault(t *testing.T) { + inputs := NewOrderedMap[string, InputField]() + info := &PredictorInfo{ + Inputs: inputs, + Output: SchemaIteratorOf(SchemaPrim(TypeString)), + Mode: ModePredict, + } + + spec := parseSpec(t, info) + post := getPath(spec, "paths", "/predictions", "post").(map[string]any) + _, ok := post["x-cog-streaming"] + assert.False(t, ok) +} + func TestOutputObject(t *testing.T) { inputs := NewOrderedMap[string, InputField]() fields := NewOrderedMap[string, SchemaField]() diff --git a/pkg/schema/python/parser.go b/pkg/schema/python/parser.go index c711ea0960..3fbac6bf65 100644 --- a/pkg/schema/python/parser.go +++ b/pkg/schema/python/parser.go @@ -60,10 +60,15 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi methodName = "train" } - funcNode, err := findTargetFunction(root, source, predictRef, methodName) + targetNode, err := findTargetFunction(root, source, predictRef, methodName) if err != nil { return nil, err } + supportsStreaming := functionSupportsStreaming(targetNode, source, imports) + funcNode := UnwrapFunction(targetNode) + if funcNode == nil { + return nil, schema.WrapError(schema.ErrParse, "target is not a function", nil) + } // 6. Check if method (has self first param) paramsNode := funcNode.ChildByFieldName("parameters") @@ -100,9 +105,10 @@ func ParsePredictor(source []byte, predictRef string, mode schema.Mode, sourceDi } return &schema.PredictorInfo{ - Inputs: inputs, - Output: output, - Mode: mode, + Inputs: inputs, + Output: output, + Mode: mode, + SupportsStreaming: supportsStreaming, }, nil } @@ -651,6 +657,39 @@ func UnwrapFunction(node *sitter.Node) *sitter.Node { return nil } +func functionSupportsStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + if node.Type() != "decorated_definition" { + return false + } + for _, child := range NamedChildren(node) { + if child.Type() != "decorator" { + continue + } + if decoratorIsCogStreaming(child, source, imports) { + return true + } + } + return false +} + +func decoratorIsCogStreaming(node *sitter.Node, source []byte, imports *schema.ImportContext) bool { + for _, child := range NamedChildren(node) { + switch child.Type() { + case "attribute": + return Content(child, source) == "cog.streaming" + case "identifier": + if Content(child, source) != "streaming" { + return false + } + entry, ok := imports.Names.Get("streaming") + return ok && entry.Module == "cog" && entry.Original == "streaming" + case "call": + return false + } + } + return false +} + func InheritsFromBaseModel(classNode *sitter.Node, source []byte, imports *schema.ImportContext) bool { supers := classNode.ChildByFieldName("superclasses") if supers == nil { @@ -1205,7 +1244,7 @@ func findTargetFunction(root *sitter.Node, source []byte, predictRef, methodName if nameNode != nil { name := Content(nameNode, source) if name == predictRef || name == methodName { - return funcNode, nil + return child, nil } } } @@ -1226,7 +1265,7 @@ func findMethodInClass(classNode *sitter.Node, source []byte, className, methodN } nameNode := funcNode.ChildByFieldName("name") if nameNode != nil && Content(nameNode, source) == methodName { - return funcNode, nil + return child, nil } } diff --git a/pkg/schema/python/parser_test.go b/pkg/schema/python/parser_test.go index 2b3d59445a..f82058e36a 100644 --- a/pkg/schema/python/parser_test.go +++ b/pkg/schema/python/parser_test.go @@ -455,6 +455,77 @@ class Predictor(BasePredictor): require.Equal(t, schema.ErrConcatIteratorNotStr, se.Kind) } +func TestStreamingDecoratorQualifiedOptIn(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorImportedOptIn(t *testing.T) { + source := ` +from cog import BasePredictor, streaming +from typing import Iterator + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.True(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorIgnoredWhenNotFromCog(t *testing.T) { + source := ` +from other import streaming +from typing import Iterator +from cog import BasePredictor + +class Predictor(BasePredictor): + @streaming + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorParameterizedFormIgnored(t *testing.T) { + source := ` +import cog +from typing import Iterator + +class Predictor(cog.BasePredictor): + @cog.streaming() + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + +func TestStreamingDecoratorClassLevelIgnored(t *testing.T) { + source := ` +import cog +from typing import Iterator + +@cog.streaming +class Predictor(cog.BasePredictor): + def predict(self) -> Iterator[str]: + yield "hello" +` + info := parse(t, source, "Predictor") + require.False(t, info.SupportsStreaming) +} + func TestListOutput(t *testing.T) { source := ` from cog import BasePredictor, Path diff --git a/pkg/schema/types.go b/pkg/schema/types.go index 5a7c7ce611..1b0250e8f7 100644 --- a/pkg/schema/types.go +++ b/pkg/schema/types.go @@ -197,9 +197,10 @@ func (f *InputField) IsRequired() bool { // PredictorInfo is the top-level extraction result. type PredictorInfo struct { - Inputs *OrderedMap[string, InputField] - Output SchemaType - Mode Mode + Inputs *OrderedMap[string, InputField] + Output SchemaType + Mode Mode + SupportsStreaming bool } // TypeAnnotation is a parsed Python type annotation (intermediate, before resolution). diff --git a/python/cog/__init__.py b/python/cog/__init__.py index 6c1fb8ce44..f7f4e1d461 100644 --- a/python/cog/__init__.py +++ b/python/cog/__init__.py @@ -20,6 +20,8 @@ def predict( """ import sys as _sys +from collections.abc import Callable +from typing import TypeVar from coglet import CancelationException as CancelationException @@ -38,6 +40,14 @@ def predict( URLPath, ) +F = TypeVar("F", bound=Callable[..., object]) + + +def streaming(fn: F) -> F: + """Mark a predict handler as supporting streaming responses.""" + fn.__cog_streaming__ = True # type: ignore[attr-defined] + return fn + # --------------------------------------------------------------------------- # Backwards-compatibility shim: ExperimentalFeatureWarning @@ -133,6 +143,7 @@ def current_scope() -> object: "CancelationException", # Metrics "current_scope", + "streaming", # Deprecated compat shims "ExperimentalFeatureWarning", "emit_metric", diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 653658ce11..6b4aead4d7 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -10,6 +10,7 @@ Path, Secret, URLFile, + streaming, ) @@ -132,3 +133,16 @@ def test_async_concatenate_iterator_is_abstract(self) -> None: from typing import AsyncIterator assert issubclass(AsyncConcatenateIterator, AsyncIterator) + + +class TestStreamingDecorator: + """Tests for the streaming opt-in decorator.""" + + def test_streaming_marks_function_and_returns_same_object(self) -> None: + def predict() -> str: + return "ok" + + decorated = streaming(predict) + + assert decorated is predict + assert predict.__cog_streaming__ is True # type: ignore[attr-defined] From 38665ac969e8bfd05fe3c66cc69d4d1a3795abf5 Mon Sep 17 00:00:00 2001 From: Mark Phelps Date: Mon, 18 May 2026 16:19:26 -0400 Subject: [PATCH 8/8] fix: match iterator CLI output in SSE opt-in test --- integration-tests/tests/sse_requires_streaming_opt_in.txtar | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/integration-tests/tests/sse_requires_streaming_opt_in.txtar b/integration-tests/tests/sse_requires_streaming_opt_in.txtar index 7edf04fab1..7ac62c51a0 100644 --- a/integration-tests/tests/sse_requires_streaming_opt_in.txtar +++ b/integration-tests/tests/sse_requires_streaming_opt_in.txtar @@ -5,7 +5,8 @@ cog serve --upload-url http://unused/ cog predict -i count=2 -stdout '"output":\["chunk-0","chunk-1"\]' +stdout '"chunk-0"' +stdout '"chunk-1"' ! curl -H Accept:text/event-stream PUT /predictions/no-streaming '{"id":"no-streaming","input":{"count":2}}' stderr 'This model does not support streaming responses'